mirror of https://github.com/comfyanonymous/ComfyUI.git

dev: Everything is Assets

This commit is contained in:
    parent c708d0a433
    commit f92307cd4c
alembic_db/versions/0001_assets.py  (new file, 158 lines)
@@ -0,0 +1,158 @@
"""initial assets schema + per-asset state cache

Revision ID: 0001_assets
Revises:
Create Date: 2025-08-20 00:00:00

"""

from alembic import op
import sqlalchemy as sa

revision = "0001_assets"
down_revision = None
branch_labels = None
depends_on = None


def upgrade() -> None:
    # ASSETS: content identity (deduplicated by hash)
    op.create_table(
        "assets",
        sa.Column("hash", sa.String(length=128), primary_key=True),
        sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
        sa.Column("mime_type", sa.String(length=255), nullable=True),
        sa.Column("refcount", sa.BigInteger(), nullable=False, server_default="0"),
        sa.Column("storage_backend", sa.String(length=32), nullable=False, server_default="fs"),
        sa.Column("storage_locator", sa.Text(), nullable=False),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
        sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
        sa.CheckConstraint("refcount >= 0", name="ck_assets_refcount_nonneg"),
    )
    op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
    op.create_index("ix_assets_backend_locator", "assets", ["storage_backend", "storage_locator"])

    # ASSETS_INFO: user-visible references (mutable metadata)
    op.create_table(
        "assets_info",
        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
        sa.Column("owner_id", sa.String(length=128), nullable=True),
        sa.Column("name", sa.String(length=512), nullable=False),
        sa.Column("asset_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False),
        sa.Column("preview_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True),
        sa.Column("user_metadata", sa.JSON(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
        sa.Column("last_access_time", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
        sqlite_autoincrement=True,
    )
    op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
    op.create_index("ix_assets_info_asset_hash", "assets_info", ["asset_hash"])
    op.create_index("ix_assets_info_name", "assets_info", ["name"])
    op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
    op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])

    # TAGS: normalized tag vocabulary
    op.create_table(
        "tags",
        sa.Column("name", sa.String(length=128), primary_key=True),
        sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
        sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
    )
    op.create_index("ix_tags_tag_type", "tags", ["tag_type"])

    # ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
    op.create_table(
        "asset_info_tags",
        sa.Column("asset_info_id", sa.BigInteger(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
        sa.Column("tag_name", sa.String(length=128), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
        sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
        sa.Column("added_by", sa.String(length=128), nullable=True),
        sa.Column("added_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
        sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
    )
    op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
    op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])

    # ASSET_LOCATOR_STATE: 1:1 filesystem metadata (for fast integrity checking) for an Asset record
    op.create_table(
        "asset_locator_state",
        sa.Column("asset_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True),
        sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
        sa.Column("etag", sa.String(length=256), nullable=True),
        sa.Column("last_modified", sa.String(length=128), nullable=True),
        sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
    )

    # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
    op.create_table(
        "asset_info_meta",
        sa.Column("asset_info_id", sa.Integer(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
        sa.Column("key", sa.String(length=256), nullable=False),
        sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("val_str", sa.String(length=2048), nullable=True),
        sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
        sa.Column("val_bool", sa.Boolean(), nullable=True),
        sa.Column("val_json", sa.JSON(), nullable=True),
        sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
    )
    op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
    op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
    op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
    op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])

    # Tags vocabulary for models
    tags_table = sa.table(
        "tags",
        sa.column("name", sa.String()),
        sa.column("tag_type", sa.String()),
    )
    op.bulk_insert(
        tags_table,
        [
            # Core concept tags
            {"name": "models", "tag_type": "system"},

            # Canonical single-word types
            {"name": "checkpoint", "tag_type": "system"},
            {"name": "lora", "tag_type": "system"},
            {"name": "vae", "tag_type": "system"},
            {"name": "text-encoder", "tag_type": "system"},
            {"name": "clip-vision", "tag_type": "system"},
            {"name": "embedding", "tag_type": "system"},
            {"name": "controlnet", "tag_type": "system"},
            {"name": "upscale", "tag_type": "system"},
            {"name": "diffusion-model", "tag_type": "system"},
            {"name": "hypernetwork", "tag_type": "system"},
            {"name": "vae_approx", "tag_type": "system"},
            # TODO: decide what to do with: style_models, diffusers, gligen, photomaker, classifiers
        ],
    )


def downgrade() -> None:
    op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
    op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
    op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
    op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
    op.drop_table("asset_info_meta")

    op.drop_table("asset_locator_state")

    op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
    op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
    op.drop_table("asset_info_tags")

    op.drop_index("ix_tags_tag_type", table_name="tags")
    op.drop_table("tags")

    op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
    op.drop_index("ix_assets_info_created_at", table_name="assets_info")
    op.drop_index("ix_assets_info_name", table_name="assets_info")
    op.drop_index("ix_assets_info_asset_hash", table_name="assets_info")
    op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
    op.drop_table("assets_info")

    op.drop_index("ix_assets_backend_locator", table_name="assets")
    op.drop_index("ix_assets_mime_type", table_name="assets")
    op.drop_table("assets")
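As a quick sanity check outside of ComfyUI's own startup path, this revision can be applied to a scratch SQLite database through Alembic's programmatic API. The following is an illustrative sketch, not part of the commit; the ini path, script location, and the throwaway database URL are assumptions:

# --- illustrative sketch (not part of this commit) ---
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")                           # repo-root alembic.ini assumed
cfg.set_main_option("script_location", "alembic_db")
cfg.set_main_option("sqlalchemy.url", "sqlite:///./assets_scratch.db")  # throwaway DB

command.upgrade(cfg, "0001_assets")   # runs upgrade() above, creating the six tables
command.downgrade(cfg, "base")        # runs downgrade(), dropping them again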
Deleted file (40 lines): the previous "init" migration, revision e9c714da8d57
@@ -1,40 +0,0 @@
"""init

Revision ID: e9c714da8d57
Revises:
Create Date: 2025-05-30 20:14:33.772039

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision: str = 'e9c714da8d57'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    op.create_table('model',
    sa.Column('type', sa.Text(), nullable=False),
    sa.Column('path', sa.Text(), nullable=False),
    sa.Column('file_name', sa.Text(), nullable=True),
    sa.Column('file_size', sa.Integer(), nullable=True),
    sa.Column('hash', sa.Text(), nullable=True),
    sa.Column('hash_algorithm', sa.Text(), nullable=True),
    sa.Column('source_url', sa.Text(), nullable=True),
    sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
    sa.PrimaryKeyConstraint('type', 'path')
    )


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('model')
    # ### end Alembic commands ###
app/api/__init__.py  (new file, 0 lines)
app/api/assets_routes.py  (new file, 110 lines)
@@ -0,0 +1,110 @@
import json
from typing import Sequence

from aiohttp import web

from app import assets_manager


ROUTES = web.RouteTableDef()


@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
    q = request.rel_url.query

    include_tags: Sequence[str] = _parse_csv_tags(q.get("include_tags"))
    exclude_tags: Sequence[str] = _parse_csv_tags(q.get("exclude_tags"))
    name_contains = q.get("name_contains")

    # Optional JSON metadata filter (top-level key equality only for now)
    metadata_filter = None
    raw_meta = q.get("metadata_filter")
    if raw_meta:
        try:
            metadata_filter = json.loads(raw_meta)
            if not isinstance(metadata_filter, dict):
                metadata_filter = None
        except Exception:
            # Silently ignore malformed JSON for the first iteration; could return 400 in the future
            metadata_filter = None

    limit = _parse_int(q.get("limit"), default=20, lo=1, hi=100)
    offset = _parse_int(q.get("offset"), default=0, lo=0, hi=10_000_000)
    sort = q.get("sort", "created_at")
    order = q.get("order", "desc")

    payload = await assets_manager.list_assets(
        include_tags=include_tags,
        exclude_tags=exclude_tags,
        name_contains=name_contains,
        metadata_filter=metadata_filter,
        limit=limit,
        offset=offset,
        sort=sort,
        order=order,
    )
    return web.json_response(payload)


@ROUTES.put("/api/assets/{id}")
async def update_asset(request: web.Request) -> web.Response:
    asset_info_id_raw = request.match_info.get("id")
    try:
        asset_info_id = int(asset_info_id_raw)
    except Exception:
        return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")

    try:
        payload = await request.json()
    except Exception:
        return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")

    name = payload.get("name", None)
    tags = payload.get("tags", None)
    user_metadata = payload.get("user_metadata", None)

    if name is None and tags is None and user_metadata is None:
        return _error_response(400, "NO_FIELDS", "Provide at least one of: name, tags, user_metadata.")

    if tags is not None and (not isinstance(tags, list) or not all(isinstance(t, str) for t in tags)):
        return _error_response(400, "INVALID_TAGS", "Field 'tags' must be an array of strings.")

    if user_metadata is not None and not isinstance(user_metadata, dict):
        return _error_response(400, "INVALID_METADATA", "Field 'user_metadata' must be an object.")

    try:
        result = await assets_manager.update_asset(
            asset_info_id=asset_info_id,
            name=name,
            tags=tags,
            user_metadata=user_metadata,
        )
    except ValueError as ve:
        return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
    except Exception:
        return _error_response(500, "INTERNAL", "Unexpected server error.")
    return web.json_response(result, status=200)


def register_assets_routes(app: web.Application) -> None:
    app.add_routes(ROUTES)


def _parse_csv_tags(raw: str | None) -> list[str]:
    if not raw:
        return []
    return [t.strip() for t in raw.split(",") if t.strip()]


def _parse_int(qval: str | None, default: int, lo: int, hi: int) -> int:
    if not qval:
        return default
    try:
        v = int(qval)
    except Exception:
        return default
    return max(lo, min(hi, v))


def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
    return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
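For reference, a minimal client-side sketch of exercising these two endpoints. It is not part of the commit; the host/port (ComfyUI's usual 8188) and the example id, name, and tags are assumptions:

# --- illustrative sketch (not part of this commit) ---
import asyncio
import aiohttp

async def main() -> None:
    base = "http://127.0.0.1:8188"  # assumes a locally running server with these routes registered
    async with aiohttp.ClientSession() as http:
        # GET /api/assets: filter by tags and name, newest first
        params = {"include_tags": "models,checkpoint", "name_contains": "sdxl",
                  "sort": "created_at", "order": "desc", "limit": "10"}
        async with http.get(f"{base}/api/assets", params=params) as resp:
            page = await resp.json()
            print(page["total"], [a["name"] for a in page["assets"]])

        # PUT /api/assets/{id}: rename and re-tag AssetInfo 42 (placeholder id)
        body = {"name": "sdxl_base_1.0.safetensors", "tags": ["models", "checkpoint"]}
        async with http.put(f"{base}/api/assets/42", json=body) as resp:
            print(resp.status, await resp.json())

asyncio.run(main())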
app/assets_manager.py  (new file, 148 lines)
@@ -0,0 +1,148 @@
import os
from datetime import datetime, timezone
from typing import Optional, Sequence

from .database.db import create_session
from .storage import hashing
from .database.services import (
    check_fs_asset_exists_quick,
    ingest_fs_asset,
    touch_asset_infos_by_fs_path,
    list_asset_infos_page,
    update_asset_info_full,
    get_asset_tags,
)


def get_size_mtime_ns(path: str) -> tuple[int, int]:
    st = os.stat(path, follow_symlinks=True)
    return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))


async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
    """Adds a local asset to the DB. If already present and unchanged, does nothing.

    Notes:
    - Uses the absolute path as the canonical locator for the 'fs' backend.
    - Computes BLAKE3 only when the fast existence check indicates it's needed.
    - This function ensures the identity row and seeds mtime in asset_locator_state.
    """
    abs_path = os.path.abspath(file_path)
    size_bytes, mtime_ns = get_size_mtime_ns(abs_path)
    if not size_bytes:
        return

    async with await create_session() as session:
        if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns):
            await touch_asset_infos_by_fs_path(session, abs_path=abs_path, ts=datetime.now(timezone.utc))
            await session.commit()
            return

    asset_hash = hashing.blake3_hash_sync(abs_path)

    async with await create_session() as session:
        await ingest_fs_asset(
            session,
            asset_hash="blake3:" + asset_hash,
            abs_path=abs_path,
            size_bytes=size_bytes,
            mtime_ns=mtime_ns,
            mime_type=None,
            info_name=file_name,
            tag_origin="automatic",
            tags=tags,
        )
        await session.commit()


async def list_assets(
    *,
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
    name_contains: Optional[str] = None,
    metadata_filter: Optional[dict] = None,
    limit: int = 20,
    offset: int = 0,
    sort: str | None = "created_at",
    order: str | None = "desc",
) -> dict:
    sort = _safe_sort_field(sort)
    order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()

    async with await create_session() as session:
        infos, tag_map, total = await list_asset_infos_page(
            session,
            include_tags=include_tags,
            exclude_tags=exclude_tags,
            name_contains=name_contains,
            metadata_filter=metadata_filter,
            limit=limit,
            offset=offset,
            sort=sort,
            order=order,
        )

        assets_json = []
        for info in infos:
            asset = info.asset  # populated via contains_eager
            tags = tag_map.get(info.id, [])
            assets_json.append(
                {
                    "id": info.id,
                    "name": info.name,
                    "asset_hash": info.asset_hash,
                    "size": int(asset.size_bytes) if asset else None,
                    "mime_type": asset.mime_type if asset else None,
                    "tags": tags,
                    "preview_url": f"/api/v1/assets/{info.id}/content",  # TODO: implement actual content endpoint later
                    "created_at": info.created_at.isoformat() if info.created_at else None,
                    "updated_at": info.updated_at.isoformat() if info.updated_at else None,
                    "last_access_time": info.last_access_time.isoformat() if info.last_access_time else None,
                }
            )

    return {
        "assets": assets_json,
        "total": total,
        "has_more": (offset + len(assets_json)) < total,
    }


async def update_asset(
    *,
    asset_info_id: int,
    name: str | None = None,
    tags: list[str] | None = None,
    user_metadata: dict | None = None,
) -> dict:
    async with await create_session() as session:
        info = await update_asset_info_full(
            session,
            asset_info_id=asset_info_id,
            name=name,
            tags=tags,
            user_metadata=user_metadata,
            tag_origin="manual",
            added_by=None,
        )

        tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
        await session.commit()

    return {
        "id": info.id,
        "name": info.name,
        "asset_hash": info.asset_hash,
        "tags": tag_names,
        "user_metadata": info.user_metadata or {},
        "updated_at": info.updated_at.isoformat() if info.updated_at else None,
    }


def _safe_sort_field(requested: str | None) -> str:
    if not requested:
        return "created_at"
    v = requested.lower()
    if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
        return v
    return "created_at"
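A sketch of how this manager could be driven when indexing a local model folder. It is not part of the commit; the paths are placeholders, and the async engine from the database module below is assumed to be initialized already:

# --- illustrative sketch (not part of this commit) ---
import asyncio
from app import assets_manager

async def index_one_checkpoint() -> None:
    # Tag names follow the vocabulary seeded by the 0001_assets migration.
    await assets_manager.add_local_asset(
        tags=["models", "checkpoint"],
        file_name="sd15.safetensors",
        file_path="models/checkpoints/sd15.safetensors",  # placeholder path
    )
    page = await assets_manager.list_assets(include_tags=["checkpoint"], limit=5)
    for item in page["assets"]:
        print(item["id"], item["name"], item["size"], item["tags"])

asyncio.run(index_one_checkpoint())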
app/database/__init__.py  (new file, 0 lines)
app/database/db.py  (modified file)
@@ -1,112 +1,267 @@
 import logging
 import os
 import shutil
+from contextlib import asynccontextmanager
+from typing import Optional
 
 from app.logger import log_startup_warning
 from utils.install_util import get_missing_requirements_message
 from comfy.cli_args import args
 
-_DB_AVAILABLE = False
-Session = None
 
+LOGGER = logging.getLogger(__name__)
 
+# Attempt imports which may not exist in some environments
 try:
     from alembic import command
     from alembic.config import Config
     from alembic.runtime.migration import MigrationContext
     from alembic.script import ScriptDirectory
-    from sqlalchemy import create_engine
-    from sqlalchemy.orm import sessionmaker
+    from sqlalchemy import create_engine, text
+    from sqlalchemy.engine import make_url
+    from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
 
     _DB_AVAILABLE = True
+    ENGINE: AsyncEngine | None = None
+    SESSION: async_sessionmaker | None = None
 except ImportError as e:
     log_startup_warning(
-        f"""
-------------------------------------------------------------------------
-Error importing dependencies: {e}
-{get_missing_requirements_message()}
-This error is happening because ComfyUI now uses a local sqlite database.
-------------------------------------------------------------------------
-""".strip()
+        (
+            "------------------------------------------------------------------------\n"
+            f"Error importing DB dependencies: {e}\n"
+            f"{get_missing_requirements_message()}\n"
+            "This error is happening because ComfyUI now uses a local database.\n"
+            "------------------------------------------------------------------------"
+        ).strip()
     )
+    _DB_AVAILABLE = False
+    ENGINE = None
+    SESSION = None
 
 
-def dependencies_available():
-    """
-    Temporary function to check if the dependencies are available
-    """
+def dependencies_available() -> bool:
+    """Check if DB dependencies are importable."""
     return _DB_AVAILABLE
 
 
-def can_create_session():
-    """
-    Temporary function to check if the database is available to create a session
-    During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
-    """
-    return dependencies_available() and Session is not None
-
-
-def get_alembic_config():
-    root_path = os.path.join(os.path.dirname(__file__), "../..")
-    config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
-    scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
-
-    config = Config(config_path)
-    config.set_main_option("script_location", scripts_path)
-    config.set_main_option("sqlalchemy.url", args.database_url)
-
-    return config
-
-
-def get_db_path():
-    url = args.database_url
-    if url.startswith("sqlite:///"):
-        return url.split("///")[1]
-    else:
-        raise ValueError(f"Unsupported database URL '{url}'.")
-
-
-def init_db():
-    db_url = args.database_url
-    logging.debug(f"Database URL: {db_url}")
-    db_path = get_db_path()
-    db_exists = os.path.exists(db_path)
-
-    config = get_alembic_config()
-
-    # Check if we need to upgrade
-    engine = create_engine(db_url)
-    conn = engine.connect()
-
-    context = MigrationContext.configure(conn)
-    current_rev = context.get_current_revision()
-
-    script = ScriptDirectory.from_config(config)
-    target_rev = script.get_current_head()
-    if target_rev is None:
-        logging.warning("No target revision found.")
-    elif current_rev != target_rev:
-        # Backup the database pre upgrade
-        backup_path = db_path + ".bkp"
-        if db_exists:
-            shutil.copy(db_path, backup_path)
-        else:
-            backup_path = None
-
-        try:
-            command.upgrade(config, target_rev)
-            logging.info(f"Database upgraded from {current_rev} to {target_rev}")
-        except Exception as e:
-            if backup_path:
-                # Restore the database from backup if upgrade fails
-                shutil.copy(backup_path, db_path)
-                os.remove(backup_path)
-            logging.exception("Error upgrading database: ")
-            raise e
-
-    global Session
-    Session = sessionmaker(bind=engine)
-
-
-def create_session():
-    return Session()
+def _root_paths():
+    """Resolve alembic.ini and migrations script folder."""
+    root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+    config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
+    scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
+    return config_path, scripts_path
+
+
+def _absolutize_sqlite_url(db_url: str) -> str:
+    """Make SQLite database path absolute. No-op for non-SQLite URLs."""
+    try:
+        u = make_url(db_url)
+    except Exception:
+        return db_url
+
+    if not u.drivername.startswith("sqlite"):
+        return db_url
+
+    # Make path absolute if relative
+    db_path = u.database or ""
+    if not os.path.isabs(db_path):
+        db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
+        u = u.set(database=db_path)
+    return str(u)
+
+
+def _to_sync_driver_url(async_url: str) -> str:
+    """Convert an async SQLAlchemy URL to a sync URL for Alembic."""
+    u = make_url(async_url)
+    driver = u.drivername
+
+    if driver.startswith("sqlite+aiosqlite"):
+        u = u.set(drivername="sqlite")
+    elif driver.startswith("postgresql+asyncpg"):
+        u = u.set(drivername="postgresql")
+    else:
+        # Generic: strip the async driver part if present
+        if "+" in driver:
+            u = u.set(drivername=driver.split("+", 1)[0])
+
+    return str(u)
+
+
+def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
+    """Return the on-disk path for a SQLite URL, else None."""
+    try:
+        u = make_url(sync_url)
+    except Exception:
+        return None
+
+    if not u.drivername.startswith("sqlite"):
+        return None
+    return u.database
+
+
+def _get_alembic_config(sync_url: str) -> Config:
+    """Prepare Alembic Config with script location and DB URL."""
+    config_path, scripts_path = _root_paths()
+    cfg = Config(config_path)
+    cfg.set_main_option("script_location", scripts_path)
+    cfg.set_main_option("sqlalchemy.url", sync_url)
+    return cfg
+
+
+async def init_db_engine() -> None:
+    """Initialize async engine + sessionmaker and run migrations to head.
+
+    This must be called once on application startup before any DB usage.
+    """
+    global ENGINE, SESSION
+
+    if not dependencies_available():
+        raise RuntimeError("Database dependencies are not available.")
+
+    if ENGINE is not None:
+        return
+
+    raw_url = args.database_url
+    if not raw_url:
+        raise RuntimeError("Database URL is not configured.")
+
+    # Absolutize SQLite path for async engine
+    db_url = _absolutize_sqlite_url(raw_url)
+
+    # Prepare async engine
+    connect_args = {}
+    if db_url.startswith("sqlite"):
+        connect_args = {
+            "check_same_thread": False,
+            "timeout": 12,
+        }
+
+    ENGINE = create_async_engine(
+        db_url,
+        connect_args=connect_args,
+        pool_pre_ping=True,
+        future=True,
+    )
+
+    # Enforce SQLite pragmas on the async engine
+    if db_url.startswith("sqlite"):
+        async with ENGINE.begin() as conn:
+            # WAL for concurrency and durability, Foreign Keys for referential integrity
+            current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
+            if str(current_mode).lower() != "wal":
+                new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
+                if str(new_mode).lower() != "wal":
+                    raise RuntimeError("Failed to set SQLite journal mode to WAL.")
+                LOGGER.info("SQLite journal mode set to WAL.")
+
+            await conn.execute(text("PRAGMA foreign_keys = ON;"))
+            await conn.execute(text("PRAGMA synchronous = NORMAL;"))
+
+    await _run_migrations(raw_url=db_url)
+
+    SESSION = async_sessionmaker(
+        bind=ENGINE,
+        class_=AsyncSession,
+        expire_on_commit=False,
+        autoflush=False,
+        autocommit=False,
+    )
+
+
+async def _run_migrations(raw_url: str) -> None:
+    """
+    Run Alembic migrations up to head.
+
+    We deliberately use a synchronous engine for migrations because Alembic's
+    programmatic API is synchronous by default and this path is robust.
+    """
+    # Convert to sync URL and make SQLite URL an absolute one
+    sync_url = _to_sync_driver_url(raw_url)
+    sync_url = _absolutize_sqlite_url(sync_url)
+
+    cfg = _get_alembic_config(sync_url)
+
+    # Inspect current and target heads
+    engine = create_engine(sync_url, future=True)
+    with engine.connect() as conn:
+        context = MigrationContext.configure(conn)
+        current_rev = context.get_current_revision()
+
+    script = ScriptDirectory.from_config(cfg)
+    target_rev = script.get_current_head()
+
+    if target_rev is None:
+        LOGGER.warning("Alembic: no target revision found.")
+        return
+
+    if current_rev == target_rev:
+        LOGGER.debug("Alembic: database already at head %s", target_rev)
+        return
+
+    LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev)
+
+    # Optional backup for SQLite file DBs
+    backup_path = None
+    sqlite_path = _get_sqlite_file_path(sync_url)
+    if sqlite_path and os.path.exists(sqlite_path):
+        backup_path = sqlite_path + ".bkp"
+        try:
+            shutil.copy(sqlite_path, backup_path)
+        except Exception as exc:
+            LOGGER.warning("Failed to create SQLite backup before migration: %s", exc)
+
+    try:
+        command.upgrade(cfg, target_rev)
+    except Exception:
+        if backup_path and os.path.exists(backup_path):
+            LOGGER.exception("Error upgrading database, attempting restore from backup.")
+            try:
+                shutil.copy(backup_path, sqlite_path)  # restore
+                os.remove(backup_path)
+            except Exception as re:
+                LOGGER.error("Failed to restore SQLite backup: %s", re)
+        else:
+            LOGGER.exception("Error upgrading database, backup is not available.")
+        raise
+
+
+def get_engine():
+    """Return the global async engine (initialized after init_db_engine())."""
+    if ENGINE is None:
+        raise RuntimeError("Engine is not initialized. Call init_db_engine() first.")
+    return ENGINE
+
+
+def get_session_maker():
+    """Return the global async_sessionmaker (initialized after init_db_engine())."""
+    if SESSION is None:
+        raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.")
+    return SESSION
+
+
+@asynccontextmanager
+async def session_scope():
+    """Async context manager for a unit of work:
+
+        async with session_scope() as sess:
+            ... use sess ...
+    """
+    maker = get_session_maker()
+    async with maker() as sess:
+        try:
+            yield sess
+            await sess.commit()
+        except Exception:
+            await sess.rollback()
+            raise
+
+
+async def create_session():
+    """Convenience helper to acquire a single AsyncSession instance.
+
+    Typical usage:
+        async with (await create_session()) as sess:
+            ...
+    """
+    maker = get_session_maker()
+    return maker()
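Taken together, startup code is expected to call init_db_engine() once and then use the session helpers. A minimal sketch, assuming this module lives at app/database/db.py and that ComfyUI's database URL is configured via comfy.cli_args; it is not part of the commit:

# --- illustrative sketch (not part of this commit) ---
import asyncio
from sqlalchemy import text
from app.database import db   # module path assumed

async def main() -> None:
    await db.init_db_engine()               # build async engine, run Alembic migrations, create sessionmaker
    async with db.session_scope() as sess:  # commits on success, rolls back on error
        count = (await sess.execute(text("SELECT COUNT(*) FROM assets"))).scalar_one()
        print("asset rows:", count)

asyncio.run(main())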
app/database/models.py  (modified file)
@@ -1,59 +1,257 @@
+from datetime import datetime
+from typing import Any, Optional
+
 from sqlalchemy import (
-    Column,
     Integer,
-    Text,
+    BigInteger,
     DateTime,
+    ForeignKey,
+    Index,
+    JSON,
+    String,
+    Text,
+    CheckConstraint,
+    Numeric,
+    Boolean,
 )
-from sqlalchemy.orm import declarative_base
 from sqlalchemy.sql import func
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign
 
-Base = declarative_base()
 
+class Base(DeclarativeBase):
+    pass
 
-def to_dict(obj):
-    fields = obj.__table__.columns.keys()
-    return {
-        field: (val.to_dict() if hasattr(val, "to_dict") else val)
-        for field in fields
-        if (val := getattr(obj, field))
-    }
 
+def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
+    fields = obj.__table__.columns.keys()
+    out: dict[str, Any] = {}
+    for field in fields:
+        val = getattr(obj, field)
+        if val is None and not include_none:
+            continue
+        if isinstance(val, datetime):
+            out[field] = val.isoformat()
+        else:
+            out[field] = val
+    return out
 
-class Model(Base):
-    """
-    sqlalchemy model representing a model file in the system.
-
-    This class defines the database schema for storing information about model files,
-    including their type, path, hash, and when they were added to the system.
-
-    Attributes:
-        type (Text): The type of the model, this is the name of the folder in the models folder (primary key)
-        path (Text): The file path of the model relative to the type folder (primary key)
-        file_name (Text): The name of the model file
-        file_size (Integer): The size of the model file in bytes
-        hash (Text): A hash of the model file
-        hash_algorithm (Text): The algorithm used to generate the hash
-        source_url (Text): The URL of the model file
-        date_added (DateTime): Timestamp of when the model was added to the system
-    """
-
-    __tablename__ = "model"
-
-    type = Column(Text, primary_key=True)
-    path = Column(Text, primary_key=True)
-    file_name = Column(Text)
-    file_size = Column(Integer)
-    hash = Column(Text)
-    hash_algorithm = Column(Text)
-    source_url = Column(Text)
-    date_added = Column(DateTime, server_default=func.now())
-
-    def to_dict(self):
-        """
-        Convert the model instance to a dictionary representation.
-
-        Returns:
-            dict: A dictionary containing the attributes of the model
-        """
-        dict = to_dict(self)
-        return dict
+
+class Asset(Base):
+    __tablename__ = "assets"
+
+    hash: Mapped[str] = mapped_column(String(256), primary_key=True)
+    size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
+    mime_type: Mapped[str | None] = mapped_column(String(255))
+    refcount: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
+    storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs")
+    storage_locator: Mapped[str] = mapped_column(Text, nullable=False)
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
+    )
+    updated_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
+    )
+
+    infos: Mapped[list["AssetInfo"]] = relationship(
+        "AssetInfo",
+        back_populates="asset",
+        primaryjoin=lambda: Asset.hash == foreign(AssetInfo.asset_hash),
+        foreign_keys=lambda: [AssetInfo.asset_hash],
+        cascade="all,delete-orphan",
+        passive_deletes=True,
+    )
+
+    preview_of: Mapped[list["AssetInfo"]] = relationship(
+        "AssetInfo",
+        back_populates="preview_asset",
+        primaryjoin=lambda: Asset.hash == foreign(AssetInfo.preview_hash),
+        foreign_keys=lambda: [AssetInfo.preview_hash],
+        viewonly=True,
+    )
+
+    locator_state: Mapped["AssetLocatorState | None"] = relationship(
+        back_populates="asset",
+        uselist=False,
+        cascade="all, delete-orphan",
+        passive_deletes=True,
+    )
+
+    __table_args__ = (
+        Index("ix_assets_mime_type", "mime_type"),
+        Index("ix_assets_backend_locator", "storage_backend", "storage_locator"),
+    )
+
+    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
+        return to_dict(self, include_none=include_none)
+
+    def __repr__(self) -> str:
+        return f"<Asset hash={self.hash[:12]} backend={self.storage_backend}>"
+
+
+class AssetLocatorState(Base):
+    __tablename__ = "asset_locator_state"
+
+    asset_hash: Mapped[str] = mapped_column(
+        String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True
+    )
+    # For fs backends: nanosecond mtime; nullable if not applicable
+    mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
+    # For HTTP/S3/GCS/Azure, etc.: optional validators
+    etag: Mapped[str | None] = mapped_column(String(256), nullable=True)
+    last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)
+
+    asset: Mapped["Asset"] = relationship(back_populates="locator_state", uselist=False)
+
+    __table_args__ = (
+        CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
+    )
+
+    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
+        return to_dict(self, include_none=include_none)
+
+    def __repr__(self) -> str:
+        return f"<AssetLocatorState hash={self.asset_hash[:12]} mtime_ns={self.mtime_ns}>"
+
+
+class AssetInfo(Base):
+    __tablename__ = "assets_info"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    owner_id: Mapped[str | None] = mapped_column(String(128))
+    name: Mapped[str] = mapped_column(String(512), nullable=False)
+    asset_hash: Mapped[str] = mapped_column(
+        String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False
+    )
+    preview_hash: Mapped[str | None] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL"))
+    user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON)
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
+    )
+    updated_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
+    )
+    last_access_time: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
+    )
+
+    # Relationships
+    asset: Mapped[Asset] = relationship(
+        "Asset",
+        back_populates="infos",
+        foreign_keys=[asset_hash],
+    )
+    preview_asset: Mapped[Asset | None] = relationship(
+        "Asset",
+        back_populates="preview_of",
+        foreign_keys=[preview_hash],
+    )
+
+    metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship(
+        back_populates="asset_info",
+        cascade="all,delete-orphan",
+        passive_deletes=True,
+    )
+
+    tag_links: Mapped[list["AssetInfoTag"]] = relationship(
+        back_populates="asset_info",
+        cascade="all,delete-orphan",
+        passive_deletes=True,
+        overlaps="tags,asset_infos",
+    )
+
+    tags: Mapped[list["Tag"]] = relationship(
+        secondary="asset_info_tags",
+        back_populates="asset_infos",
+        lazy="joined",
+        viewonly=True,
+        overlaps="tag_links,asset_info_links,asset_infos,tag",
+    )
+
+    __table_args__ = (
+        Index("ix_assets_info_owner_id", "owner_id"),
+        Index("ix_assets_info_asset_hash", "asset_hash"),
+        Index("ix_assets_info_name", "name"),
+        Index("ix_assets_info_created_at", "created_at"),
+        Index("ix_assets_info_last_access_time", "last_access_time"),
+        {"sqlite_autoincrement": True},
+    )
+
+    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
+        data = to_dict(self, include_none=include_none)
+        data["tags"] = [t.name for t in self.tags]
+        return data
+
+    def __repr__(self) -> str:
+        return f"<AssetInfo id={self.id} name={self.name!r} hash={self.asset_hash[:12]}>"
+
+
+class AssetInfoMeta(Base):
+    __tablename__ = "asset_info_meta"
+
+    asset_info_id: Mapped[int] = mapped_column(
+        Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
+    )
+    key: Mapped[str] = mapped_column(String(256), primary_key=True)
+    ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
+
+    val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
+    val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True)
+    val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
+    val_json: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True)
+
+    asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries")
+
+    __table_args__ = (
+        Index("ix_asset_info_meta_key", "key"),
+        Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
+        Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
+        Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
+    )
+
+
+class AssetInfoTag(Base):
+    __tablename__ = "asset_info_tags"
+
+    asset_info_id: Mapped[int] = mapped_column(
+        Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
+    )
+    tag_name: Mapped[str] = mapped_column(
+        String(128), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
+    )
+    origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
+    added_by: Mapped[str | None] = mapped_column(String(128))
+    added_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
+
+    asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links")
+    tag: Mapped["Tag"] = relationship(back_populates="asset_info_links")
+
+    __table_args__ = (
+        Index("ix_asset_info_tags_tag_name", "tag_name"),
+        Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
+    )
+
+
+class Tag(Base):
+    __tablename__ = "tags"
+
+    name: Mapped[str] = mapped_column(String(128), primary_key=True)
+    tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
+
+    asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(
+        back_populates="tag",
+        overlaps="asset_infos,tags",
+    )
+    asset_infos: Mapped[list["AssetInfo"]] = relationship(
+        secondary="asset_info_tags",
+        back_populates="tags",
+        viewonly=True,
+        overlaps="asset_info_links,tag_links,tags,asset_info",
+    )
+
+    __table_args__ = (
+        Index("ix_tags_tag_type", "tag_type"),
+    )
+
+    def __repr__(self) -> str:
+        return f"<Tag {self.name}>"
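To illustrate the split between content identity (Asset) and user-visible references (AssetInfo), a small sketch building transient ORM objects; it is not part of the commit, the module path and hash value are assumptions, and persistence would normally go through the async session helpers above:

# --- illustrative sketch (not part of this commit) ---
from datetime import datetime, timezone
from app.database.models import Asset, AssetInfo   # path assumed

now = datetime.now(timezone.utc)

# One deduplicated content row, keyed by hash...
asset = Asset(
    hash="blake3:0123abcd",   # placeholder, not a real digest
    size_bytes=1234,
    storage_backend="fs",
    storage_locator="/models/checkpoints/example.safetensors",
)

# ...referenced by any number of AssetInfo rows carrying mutable, user-facing metadata.
ref_a = AssetInfo(name="example.safetensors", asset_hash=asset.hash,
                  created_at=now, updated_at=now, last_access_time=now)
ref_b = AssetInfo(name="my favourite checkpoint", asset_hash=asset.hash,
                  created_at=now, updated_at=now, last_access_time=now)

print(asset)            # <Asset hash=blake3:0123a backend=fs>
print(ref_a.to_dict())  # column values plus a "tags" list (empty until tag links exist)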
683
app/database/services.py
Normal file
683
app/database/services.py
Normal file
@ -0,0 +1,683 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from decimal import Decimal
|
||||||
|
from typing import Any, Sequence, Optional, Iterable
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy import select, delete, exists, func
|
||||||
|
from sqlalchemy.orm import contains_eager
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
|
||||||
|
from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
|
||||||
|
|
||||||
|
|
||||||
|
async def check_fs_asset_exists_quick(
|
||||||
|
session,
|
||||||
|
*,
|
||||||
|
file_path: str,
|
||||||
|
size_bytes: Optional[int] = None,
|
||||||
|
mtime_ns: Optional[int] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Returns 'True' if there is already an Asset present whose canonical locator matches this absolute path,
|
||||||
|
AND (if provided) mtime_ns matches stored locator-state,
|
||||||
|
AND (if provided) size_bytes matches verified size when known.
|
||||||
|
"""
|
||||||
|
locator = os.path.abspath(file_path)
|
||||||
|
|
||||||
|
stmt = select(sa.literal(True)).select_from(Asset)
|
||||||
|
|
||||||
|
conditions = [
|
||||||
|
Asset.storage_backend == "fs",
|
||||||
|
Asset.storage_locator == locator,
|
||||||
|
]
|
||||||
|
|
||||||
|
# If size_bytes provided require equality when the asset has a verified (non-zero) size.
|
||||||
|
# If verified size is 0 (unknown), we don't force equality.
|
||||||
|
if size_bytes is not None:
|
||||||
|
conditions.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
|
||||||
|
|
||||||
|
# If mtime_ns provided require the locator-state to exist and match.
|
||||||
|
if mtime_ns is not None:
|
||||||
|
stmt = stmt.join(AssetLocatorState, AssetLocatorState.asset_hash == Asset.hash)
|
||||||
|
conditions.append(AssetLocatorState.mtime_ns == int(mtime_ns))
|
||||||
|
|
||||||
|
stmt = stmt.where(*conditions).limit(1)
|
||||||
|
|
||||||
|
row = (await session.execute(stmt)).first()
|
||||||
|
return row is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def ingest_fs_asset(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_hash: str,
|
||||||
|
abs_path: str,
|
||||||
|
size_bytes: int,
|
||||||
|
mtime_ns: int,
|
||||||
|
mime_type: Optional[str] = None,
|
||||||
|
info_name: Optional[str] = None,
|
||||||
|
owner_id: Optional[str] = None,
|
||||||
|
preview_hash: Optional[str] = None,
|
||||||
|
user_metadata: Optional[dict] = None,
|
||||||
|
tags: Sequence[str] = (),
|
||||||
|
tag_origin: str = "manual",
|
||||||
|
added_by: Optional[str] = None,
|
||||||
|
require_existing_tags: bool = False,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Creates or updates Asset record for a local (fs) asset.
|
||||||
|
|
||||||
|
Always:
|
||||||
|
- Insert Asset if missing; else update size_bytes (and updated_at) if different.
|
||||||
|
- Insert AssetLocatorState if missing; else update mtime_ns if different.
|
||||||
|
|
||||||
|
Optionally (when info_name is provided):
|
||||||
|
- Create an AssetInfo (no refcount changes).
|
||||||
|
- Link provided tags to that AssetInfo.
|
||||||
|
* If the require_existing_tags=True, raises ValueError if any tag does not exist in `tags` table.
|
||||||
|
* If False (default), silently skips unknown tags.
|
||||||
|
|
||||||
|
Returns flags and ids:
|
||||||
|
{
|
||||||
|
"asset_created": bool,
|
||||||
|
"asset_updated": bool,
|
||||||
|
"state_created": bool,
|
||||||
|
"state_updated": bool,
|
||||||
|
"asset_info_id": int | None,
|
||||||
|
"tags_added": list[str],
|
||||||
|
"tags_missing": list[str], # filled only when require_existing_tags=False
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
locator = os.path.abspath(abs_path)
|
||||||
|
datetime_now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
out = {
|
||||||
|
"asset_created": False,
|
||||||
|
"asset_updated": False,
|
||||||
|
"state_created": False,
|
||||||
|
"state_updated": False,
|
||||||
|
"asset_info_id": None,
|
||||||
|
"tags_added": [],
|
||||||
|
"tags_missing": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---- Step 1: INSERT Asset or UPDATE size_bytes/updated_at if exists ----
|
||||||
|
async with session.begin_nested() as sp1:
|
||||||
|
try:
|
||||||
|
session.add(
|
||||||
|
Asset(
|
||||||
|
hash=asset_hash,
|
||||||
|
size_bytes=int(size_bytes),
|
||||||
|
mime_type=mime_type,
|
||||||
|
refcount=0,
|
||||||
|
storage_backend="fs",
|
||||||
|
storage_locator=locator,
|
||||||
|
created_at=datetime_now,
|
||||||
|
updated_at=datetime_now,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await session.flush()
|
||||||
|
out["asset_created"] = True
|
||||||
|
except IntegrityError:
|
||||||
|
await sp1.rollback()
|
||||||
|
# Already exists by hash -> update selected fields if different
|
||||||
|
existing = await session.get(Asset, asset_hash)
|
||||||
|
if existing is not None:
|
||||||
|
desired_size = int(size_bytes)
|
||||||
|
if existing.size_bytes != desired_size:
|
||||||
|
existing.size_bytes = desired_size
|
||||||
|
existing.updated_at = datetime_now
|
||||||
|
out["asset_updated"] = True
|
||||||
|
else:
|
||||||
|
# This should not occur. Log for visibility.
|
||||||
|
logging.error("Asset %s not found after conflict; skipping update.", asset_hash)
|
||||||
|
except Exception:
|
||||||
|
await sp1.rollback()
|
||||||
|
logging.exception("Unexpected error inserting Asset (hash=%s, locator=%s)", asset_hash, locator)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ----
|
||||||
|
async with session.begin_nested() as sp2:
|
||||||
|
try:
|
||||||
|
session.add(
|
||||||
|
AssetLocatorState(
|
||||||
|
asset_hash=asset_hash,
|
||||||
|
mtime_ns=int(mtime_ns),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await session.flush()
|
||||||
|
out["state_created"] = True
|
||||||
|
except IntegrityError:
|
||||||
|
await sp2.rollback()
|
||||||
|
state = await session.get(AssetLocatorState, asset_hash)
|
||||||
|
if state is not None:
|
||||||
|
desired_mtime = int(mtime_ns)
|
||||||
|
if state.mtime_ns != desired_mtime:
|
||||||
|
state.mtime_ns = desired_mtime
|
||||||
|
out["state_updated"] = True
|
||||||
|
else:
|
||||||
|
logging.debug("Locator state missing for %s after conflict; skipping update.", asset_hash)
|
||||||
|
except Exception:
|
||||||
|
await sp2.rollback()
|
||||||
|
logging.exception("Unexpected error inserting AssetLocatorState (hash=%s)", asset_hash)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# ---- Optional: AssetInfo + tag links ----
|
||||||
|
if info_name:
|
||||||
|
# 2a) Create AssetInfo (no refcount bump)
|
||||||
|
async with session.begin_nested() as sp3:
|
||||||
|
try:
|
||||||
|
info = AssetInfo(
|
||||||
|
owner_id=owner_id,
|
||||||
|
name=info_name,
|
||||||
|
asset_hash=asset_hash,
|
||||||
|
preview_hash=preview_hash,
|
||||||
|
created_at=datetime_now,
|
||||||
|
updated_at=datetime_now,
|
||||||
|
last_access_time=datetime_now,
|
||||||
|
)
|
||||||
|
session.add(info)
|
||||||
|
await session.flush() # get info.id
|
||||||
|
out["asset_info_id"] = info.id
|
||||||
|
except Exception:
|
||||||
|
await sp3.rollback()
|
||||||
|
logging.exception(
|
||||||
|
"Unexpected error inserting AssetInfo (hash=%s, name=%s)", asset_hash, info_name
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# 2b) Link tags (if any). We DO NOT create new Tag rows here by default.
|
||||||
|
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||||
|
if norm and out["asset_info_id"] is not None:
|
||||||
|
# Which tags exist?
|
||||||
|
existing_tag_names = set(
|
||||||
|
name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||||
|
)
|
||||||
|
missing = [t for t in norm if t not in existing_tag_names]
|
||||||
|
if missing and require_existing_tags:
|
||||||
|
raise ValueError(f"Unknown tags: {missing}")
|
||||||
|
|
||||||
|
# Which links already exist?
|
||||||
|
existing_links = set(
|
||||||
|
tag_name
|
||||||
|
for (tag_name,) in (
|
||||||
|
await session.execute(
|
||||||
|
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
)
|
||||||
|
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||||
|
if to_add:
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
AssetInfoTag(
|
||||||
|
asset_info_id=out["asset_info_id"],
|
||||||
|
tag_name=t,
|
||||||
|
origin=tag_origin,
|
||||||
|
added_by=added_by,
|
||||||
|
added_at=datetime_now,
|
||||||
|
)
|
||||||
|
for t in to_add
|
||||||
|
]
|
||||||
|
)
|
||||||
|
await session.flush()
|
||||||
|
out["tags_added"] = to_add
|
||||||
|
out["tags_missing"] = missing
|
||||||
|
|
||||||
|
# 2c) Rebuild metadata projection if provided
|
||||||
|
if user_metadata is not None and out["asset_info_id"] is not None:
|
||||||
|
await replace_asset_info_metadata_projection(
|
||||||
|
session,
|
||||||
|
asset_info_id=out["asset_info_id"],
|
||||||
|
user_metadata=user_metadata,
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
async def touch_asset_infos_by_fs_path(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
abs_path: str,
|
||||||
|
ts: Optional[datetime] = None,
|
||||||
|
only_if_newer: bool = True,
|
||||||
|
) -> int:
|
||||||
|
locator = os.path.abspath(abs_path)
|
||||||
|
ts = ts or datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
stmt = sa.update(AssetInfo).where(
|
||||||
|
sa.exists(
|
||||||
|
sa.select(sa.literal(1))
|
||||||
|
.select_from(Asset)
|
||||||
|
.where(
|
||||||
|
Asset.hash == AssetInfo.asset_hash,
|
||||||
|
Asset.storage_backend == "fs",
|
||||||
|
Asset.storage_locator == locator,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if only_if_newer:
|
||||||
|
stmt = stmt.where(
|
||||||
|
sa.or_(
|
||||||
|
AssetInfo.last_access_time.is_(None),
|
||||||
|
AssetInfo.last_access_time < ts,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
stmt = stmt.values(last_access_time=ts)
|
||||||
|
|
||||||
|
res = await session.execute(stmt)
|
||||||
|
return int(res.rowcount or 0)
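A minimal usage sketch for the touch helper above (illustrative only; `session` is assumed to be an open AsyncSession and the path is a placeholder value):
import logging
from datetime import datetime, timezone

async def touch_after_read(session, path: str) -> None:
    # Bump last_access_time on every AssetInfo whose backing Asset lives at this fs path.
    touched = await touch_asset_infos_by_fs_path(
        session, abs_path=path, ts=datetime.now(timezone.utc), only_if_newer=True
    )
    logging.debug("Bumped last_access_time on %d asset_info row(s) for %s", touched, path)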
|
||||||
|
|
||||||
|
|
||||||
|
async def list_asset_infos_page(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
include_tags: Sequence[str] | None = None,
|
||||||
|
exclude_tags: Sequence[str] | None = None,
|
||||||
|
name_contains: str | None = None,
|
||||||
|
metadata_filter: dict | None = None,
|
||||||
|
limit: int = 20,
|
||||||
|
offset: int = 0,
|
||||||
|
sort: str = "created_at",
|
||||||
|
order: str = "desc",
|
||||||
|
) -> tuple[list[AssetInfo], dict[int, list[str]], int]:
|
||||||
|
"""
|
||||||
|
Returns a page of AssetInfo rows with their Asset eagerly loaded (no N+1),
|
||||||
|
plus a map of asset_info_id -> [tags], and the total count.
|
||||||
|
|
||||||
|
We purposely collect tags in a separate (single) query to avoid row explosion.
|
||||||
|
"""
|
||||||
|
# Clamp
|
||||||
|
if limit <= 0:
|
||||||
|
limit = 1
|
||||||
|
if limit > 100:
|
||||||
|
limit = 100
|
||||||
|
if offset < 0:
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
# Build base query
|
||||||
|
base = (
|
||||||
|
select(AssetInfo)
|
||||||
|
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||||
|
.options(contains_eager(AssetInfo.asset))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filters
|
||||||
|
if name_contains:
|
||||||
|
base = base.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
||||||
|
|
||||||
|
base = _apply_tag_filters(base, include_tags, exclude_tags)
|
||||||
|
|
||||||
|
base = _apply_metadata_filter(base, metadata_filter)
|
||||||
|
|
||||||
|
# Sort
|
||||||
|
sort = (sort or "created_at").lower()
|
||||||
|
order = (order or "desc").lower()
|
||||||
|
sort_map = {
|
||||||
|
"name": AssetInfo.name,
|
||||||
|
"created_at": AssetInfo.created_at,
|
||||||
|
"updated_at": AssetInfo.updated_at,
|
||||||
|
"last_access_time": AssetInfo.last_access_time,
|
||||||
|
"size": Asset.size_bytes,
|
||||||
|
}
|
||||||
|
sort_col = sort_map.get(sort, AssetInfo.created_at)
|
||||||
|
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||||
|
|
||||||
|
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||||
|
|
||||||
|
# Total count (same filters, no ordering/limit/offset)
|
||||||
|
count_stmt = (
|
||||||
|
select(func.count())
|
||||||
|
.select_from(AssetInfo)
|
||||||
|
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||||
|
)
|
||||||
|
if name_contains:
|
||||||
|
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
||||||
|
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||||
|
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
|
||||||
|
|
||||||
|
total = (await session.execute(count_stmt)).scalar_one()
|
||||||
|
|
||||||
|
# Fetch rows
|
||||||
|
infos = (await session.execute(base)).scalars().unique().all()
|
||||||
|
|
||||||
|
# Collect tags in bulk (single query)
|
||||||
|
id_list = [i.id for i in infos]
|
||||||
|
tag_map: dict[int, list[str]] = defaultdict(list)
|
||||||
|
if id_list:
|
||||||
|
rows = await session.execute(
|
||||||
|
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||||
|
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||||
|
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||||
|
)
|
||||||
|
for aid, tag_name in rows.all():
|
||||||
|
tag_map[aid].append(tag_name)
|
||||||
|
|
||||||
|
return infos, tag_map, total
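A minimal paging sketch for the query above (illustrative only; `session` is assumed to be an open AsyncSession, tag and search values are placeholders):
async def print_first_page(session) -> None:
    infos, tag_map, total = await list_asset_infos_page(
        session,
        include_tags=["models", "checkpoints"],
        name_contains="sdxl",
        limit=20,
        sort="created_at",
        order="desc",
    )
    print(f"{len(infos)} of {total} matching assets")
    for info in infos:
        # info.asset is populated via contains_eager() in the query above.
        print(info.id, info.name, info.asset.size_bytes, tag_map.get(info.id, []))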
|
||||||
|
|
||||||
|
|
||||||
|
async def set_asset_info_tags(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_info_id: int,
|
||||||
|
tags: Sequence[str],
|
||||||
|
origin: str = "manual",
|
||||||
|
added_by: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Replace the tag set on an AssetInfo with `tags`. Idempotent.
|
||||||
|
Creates missing tag names as 'user'.
|
||||||
|
"""
|
||||||
|
desired = _normalize_tags(tags)
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# current links
|
||||||
|
current = set(
|
||||||
|
tag_name for (tag_name,) in (
|
||||||
|
await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||||
|
).all()
|
||||||
|
)
|
||||||
|
|
||||||
|
to_add = [t for t in desired if t not in current]
|
||||||
|
to_remove = [t for t in current if t not in desired]
|
||||||
|
|
||||||
|
if to_add:
|
||||||
|
await _ensure_tags_exist(session, to_add, tag_type="user")
|
||||||
|
session.add_all([
|
||||||
|
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_by=added_by, added_at=now)
|
||||||
|
for t in to_add
|
||||||
|
])
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
await session.execute(
|
||||||
|
delete(AssetInfoTag)
|
||||||
|
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||||
|
)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||||
|
|
||||||
|
|
||||||
|
async def update_asset_info_full(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_info_id: int,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
tags: Optional[Sequence[str]] = None,
|
||||||
|
user_metadata: Optional[dict] = None,
|
||||||
|
tag_origin: str = "manual",
|
||||||
|
added_by: Optional[str] = None,
|
||||||
|
) -> AssetInfo:
|
||||||
|
"""
|
||||||
|
Update AssetInfo fields:
|
||||||
|
- name (if provided)
|
||||||
|
- user_metadata blob + rebuild projection (if provided)
|
||||||
|
- replace tags with provided set (if provided)
|
||||||
|
Returns the updated AssetInfo.
|
||||||
|
"""
|
||||||
|
info = await session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
touched = False
|
||||||
|
if name is not None and name != info.name:
|
||||||
|
info.name = name
|
||||||
|
touched = True
|
||||||
|
|
||||||
|
if user_metadata is not None:
|
||||||
|
await replace_asset_info_metadata_projection(
|
||||||
|
session, asset_info_id=asset_info_id, user_metadata=user_metadata
|
||||||
|
)
|
||||||
|
touched = True
|
||||||
|
|
||||||
|
if tags is not None:
|
||||||
|
await set_asset_info_tags(
|
||||||
|
session,
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tags=tags,
|
||||||
|
origin=tag_origin,
|
||||||
|
added_by=added_by,
|
||||||
|
)
|
||||||
|
touched = True
|
||||||
|
|
||||||
|
if touched and user_metadata is None:
|
||||||
|
info.updated_at = datetime.now(timezone.utc)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
return info
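A minimal usage sketch of the combined update above (illustrative only; the name, tags, and metadata values are placeholders):
async def rename_and_annotate(session, asset_info_id: int) -> None:
    # Rename, replace the tag set, and rebuild the metadata projection in one call.
    info = await update_asset_info_full(
        session,
        asset_info_id=asset_info_id,
        name="sdxl-base-1.0.safetensors",
        tags=["models", "checkpoints"],
        user_metadata={"base_model": "sdxl", "steps": 30},
    )
    print(info.name, info.updated_at)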
|
||||||
|
|
||||||
|
|
||||||
|
async def replace_asset_info_metadata_projection(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_info_id: int,
|
||||||
|
user_metadata: dict | None,
|
||||||
|
) -> None:
|
||||||
|
"""Replaces the `assets_info.user_metadata` AND rebuild the projection rows in `asset_info_meta`."""
|
||||||
|
info = await session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
info.user_metadata = user_metadata or {}
|
||||||
|
info.updated_at = datetime.now(timezone.utc)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
if not user_metadata:
|
||||||
|
return
|
||||||
|
|
||||||
|
rows: list[AssetInfoMeta] = []
|
||||||
|
for k, v in user_metadata.items():
|
||||||
|
for r in _project_kv(k, v):
|
||||||
|
rows.append(
|
||||||
|
AssetInfoMeta(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
key=r["key"],
|
||||||
|
ordinal=int(r["ordinal"]),
|
||||||
|
val_str=r.get("val_str"),
|
||||||
|
val_num=r.get("val_num"),
|
||||||
|
val_bool=r.get("val_bool"),
|
||||||
|
val_json=r.get("val_json"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if rows:
|
||||||
|
session.add_all(rows)
|
||||||
|
await session.flush()
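A minimal usage sketch of the projection rebuild above (illustrative only; the metadata values are placeholders):
async def set_metadata(session, asset_info_id: int) -> None:
    # Stores the blob on assets_info.user_metadata and rewrites the typed
    # rows in asset_info_meta that _apply_metadata_filter() queries.
    await replace_asset_info_metadata_projection(
        session,
        asset_info_id=asset_info_id,
        user_metadata={"epoch": 10, "trigger_words": ["foo", "bar"], "nsfw": False},
    )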
|
||||||
|
|
||||||
|
|
||||||
|
async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[str]:
|
||||||
|
return [
|
||||||
|
tag_name
|
||||||
|
for (tag_name,) in (
|
||||||
|
await session.execute(
|
||||||
|
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_tags(tags: Sequence[str] | None) -> list[str]:
|
||||||
|
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||||
|
|
||||||
|
|
||||||
|
async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
|
||||||
|
wanted = _normalize_tags(list(names))
|
||||||
|
if not wanted:
|
||||||
|
return []
|
||||||
|
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
|
||||||
|
by_name = {t.name: t for t in existing}
|
||||||
|
to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name]
|
||||||
|
if to_create:
|
||||||
|
session.add_all(to_create)
|
||||||
|
await session.flush()
|
||||||
|
by_name.update({t.name: t for t in to_create})
|
||||||
|
return [by_name[n] for n in wanted]
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_tag_filters(
|
||||||
|
stmt: sa.sql.Select,
|
||||||
|
include_tags: Sequence[str] | None,
|
||||||
|
exclude_tags: Sequence[str] | None,
|
||||||
|
) -> sa.sql.Select:
|
||||||
|
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||||
|
include_tags = _normalize_tags(include_tags)
|
||||||
|
exclude_tags = _normalize_tags(exclude_tags)
|
||||||
|
|
||||||
|
if include_tags:
|
||||||
|
for tag_name in include_tags:
|
||||||
|
stmt = stmt.where(
|
||||||
|
exists().where(
|
||||||
|
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||||
|
& (AssetInfoTag.tag_name == tag_name)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if exclude_tags:
|
||||||
|
stmt = stmt.where(
|
||||||
|
~exists().where(
|
||||||
|
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||||
|
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return stmt
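A short sketch of the filter semantics above: every include tag must be present and no exclude tag may be present (relies on this module's own imports; the tag names are placeholders):
stmt = _apply_tag_filters(
    select(AssetInfo),
    include_tags=["models", "vae"],
    exclude_tags=["broken"],
)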
|
||||||
|
|
||||||
|
def _apply_metadata_filter(
|
||||||
|
stmt: sa.sql.Select,
|
||||||
|
metadata_filter: dict | None,
|
||||||
|
) -> sa.sql.Select:
|
||||||
|
"""Apply metadata filters using the projection table asset_info_meta.
|
||||||
|
|
||||||
|
Semantics:
|
||||||
|
- For scalar values: require EXISTS(asset_info_meta) with matching key + typed value.
|
||||||
|
- For None: key is missing OR key has explicit null (val_json IS NULL).
|
||||||
|
- For list values: ANY-of the list elements matches (EXISTS for any).
|
||||||
|
(Change to ALL-of by 'for each element: stmt = stmt.where(_exists_clause_for_value(key, elem))')
|
||||||
|
"""
|
||||||
|
if not metadata_filter:
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||||
|
subquery = (
|
||||||
|
select(sa.literal(1))
|
||||||
|
.select_from(AssetInfoMeta)
|
||||||
|
.where(
|
||||||
|
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||||
|
AssetInfoMeta.key == key,
|
||||||
|
*preds,
|
||||||
|
)
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
return sa.exists(subquery)
|
||||||
|
|
||||||
|
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||||
|
# Missing OR null:
|
||||||
|
if value is None:
|
||||||
|
# either: no row for key OR a row for key with explicit null
|
||||||
|
no_row_for_key = ~sa.exists(
|
||||||
|
select(sa.literal(1))
|
||||||
|
.select_from(AssetInfoMeta)
|
||||||
|
.where(
|
||||||
|
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||||
|
AssetInfoMeta.key == key,
|
||||||
|
)
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
null_row = _exists_for_pred(key, AssetInfoMeta.val_json.is_(None))
|
||||||
|
return sa.or_(no_row_for_key, null_row)
|
||||||
|
|
||||||
|
# Typed scalar matches:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||||
|
if isinstance(value, (int, float, Decimal)):
|
||||||
|
# store as Decimal for equality against NUMERIC(38,10)
|
||||||
|
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||||
|
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||||
|
if isinstance(value, str):
|
||||||
|
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||||
|
|
||||||
|
# Complex: compare JSON (no index, but supported)
|
||||||
|
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||||
|
|
||||||
|
for k, v in metadata_filter.items():
|
||||||
|
if isinstance(v, list):
|
||||||
|
# ANY-of (exists for any element)
|
||||||
|
ors = [ _exists_clause_for_value(k, elem) for elem in v ]
|
||||||
|
if ors:
|
||||||
|
stmt = stmt.where(sa.or_(*ors))
|
||||||
|
else:
|
||||||
|
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||||
|
return stmt
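A short sketch of the documented filter semantics above (relies on this module's own imports; the keys and values are placeholders):
stmt = _apply_metadata_filter(
    select(AssetInfo),
    {
        "base_model": "sdxl",       # scalar -> EXISTS row with val_str == "sdxl"
        "nsfw": False,              # bool   -> matched against val_bool
        "epoch": [10, 20],          # list   -> ANY-of: epoch == 10 OR epoch == 20
        "deprecated_reason": None,  # None   -> key missing OR explicit null row
    },
)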
|
||||||
|
|
||||||
|
|
||||||
|
def _is_scalar(v: Any) -> bool:
|
||||||
|
if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries
|
||||||
|
return True
|
||||||
|
if isinstance(v, bool):
|
||||||
|
return True
|
||||||
|
if isinstance(v, (int, float, Decimal, str)):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _project_kv(key: str, value: Any) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Turn a metadata key/value into one or more projection rows:
|
||||||
|
- scalar -> one row (ordinal=0) in the proper typed column
|
||||||
|
- list of scalars -> one row per element with ordinal=i
|
||||||
|
- dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list)
|
||||||
|
- None -> single row with val_json = None
|
||||||
|
Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...}
|
||||||
|
"""
|
||||||
|
rows: list[dict] = []
|
||||||
|
|
||||||
|
# None
|
||||||
|
if value is None:
|
||||||
|
rows.append({"key": key, "ordinal": 0, "val_json": None})
|
||||||
|
return rows
|
||||||
|
|
||||||
|
# Scalars
|
||||||
|
if _is_scalar(value):
|
||||||
|
if isinstance(value, bool):
|
||||||
|
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||||
|
elif isinstance(value, (int, float, Decimal)):
|
||||||
|
# store numeric; SQLAlchemy will coerce to Numeric
|
||||||
|
rows.append({"key": key, "ordinal": 0, "val_num": value})
|
||||||
|
elif isinstance(value, str):
|
||||||
|
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||||
|
else:
|
||||||
|
# Fallback to json
|
||||||
|
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||||
|
return rows
|
||||||
|
|
||||||
|
# Lists
|
||||||
|
if isinstance(value, list):
|
||||||
|
# list of scalars?
|
||||||
|
if all(_is_scalar(x) for x in value):
|
||||||
|
for i, x in enumerate(value):
|
||||||
|
if x is None:
|
||||||
|
rows.append({"key": key, "ordinal": i, "val_json": None})
|
||||||
|
elif isinstance(x, bool):
|
||||||
|
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||||
|
elif isinstance(x, (int, float, Decimal)):
|
||||||
|
rows.append({"key": key, "ordinal": i, "val_num": x})
|
||||||
|
elif isinstance(x, str):
|
||||||
|
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||||
|
else:
|
||||||
|
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||||
|
return rows
|
||||||
|
# list contains objects -> one val_json per element
|
||||||
|
for i, x in enumerate(value):
|
||||||
|
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||||
|
return rows
|
||||||
|
|
||||||
|
# Dict or any other structure -> single json row
|
||||||
|
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||||
|
return rows
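Two expected projections, following directly from the branches above (the key names and values are placeholders):
assert _project_kv("nsfw", False) == [{"key": "nsfw", "ordinal": 0, "val_bool": False}]
assert _project_kv("trigger_words", ["foo", "bar"]) == [
    {"key": "trigger_words", "ordinal": 0, "val_str": "foo"},
    {"key": "trigger_words", "ordinal": 1, "val_str": "bar"},
]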
|
||||||
@ -1,331 +0,0 @@
|
|||||||
import os
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from tqdm import tqdm
|
|
||||||
from folder_paths import get_relative_path, get_full_path
|
|
||||||
from app.database.db import create_session, dependencies_available, can_create_session
|
|
||||||
import blake3
|
|
||||||
import comfy.utils
|
|
||||||
|
|
||||||
|
|
||||||
if dependencies_available():
|
|
||||||
from app.database.models import Model
|
|
||||||
|
|
||||||
|
|
||||||
class ModelProcessor:
|
|
||||||
def _validate_path(self, model_path):
|
|
||||||
try:
|
|
||||||
if not self._file_exists(model_path):
|
|
||||||
logging.error(f"Model file not found: {model_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
result = get_relative_path(model_path)
|
|
||||||
if not result:
|
|
||||||
logging.error(
|
|
||||||
f"Model file not in a recognized model directory: {model_path}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error validating model path {model_path}: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _file_exists(self, path):
|
|
||||||
"""Check if a file exists."""
|
|
||||||
return os.path.exists(path)
|
|
||||||
|
|
||||||
def _get_file_size(self, path):
|
|
||||||
"""Get file size."""
|
|
||||||
return os.path.getsize(path)
|
|
||||||
|
|
||||||
def _get_hasher(self):
|
|
||||||
return blake3.blake3()
|
|
||||||
|
|
||||||
def _hash_file(self, model_path):
|
|
||||||
try:
|
|
||||||
hasher = self._get_hasher()
|
|
||||||
with open(model_path, "rb", buffering=0) as f:
|
|
||||||
b = bytearray(128 * 1024)
|
|
||||||
mv = memoryview(b)
|
|
||||||
while n := f.readinto(mv):
|
|
||||||
hasher.update(mv[:n])
|
|
||||||
return hasher.hexdigest()
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error hashing file {model_path}: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _get_existing_model(self, session, model_type, model_relative_path):
|
|
||||||
return (
|
|
||||||
session.query(Model)
|
|
||||||
.filter(Model.type == model_type)
|
|
||||||
.filter(Model.path == model_relative_path)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
def _ensure_source_url(self, session, model, source_url):
|
|
||||||
if model.source_url is None:
|
|
||||||
model.source_url = source_url
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
def _update_database(
|
|
||||||
self,
|
|
||||||
session,
|
|
||||||
model_type,
|
|
||||||
model_path,
|
|
||||||
model_relative_path,
|
|
||||||
model_hash,
|
|
||||||
model,
|
|
||||||
source_url,
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
if not model:
|
|
||||||
model = self._get_existing_model(
|
|
||||||
session, model_type, model_relative_path
|
|
||||||
)
|
|
||||||
|
|
||||||
if not model:
|
|
||||||
model = Model(
|
|
||||||
path=model_relative_path,
|
|
||||||
type=model_type,
|
|
||||||
file_name=os.path.basename(model_path),
|
|
||||||
)
|
|
||||||
session.add(model)
|
|
||||||
|
|
||||||
model.file_size = self._get_file_size(model_path)
|
|
||||||
model.hash = model_hash
|
|
||||||
if model_hash:
|
|
||||||
model.hash_algorithm = "blake3"
|
|
||||||
model.source_url = source_url
|
|
||||||
|
|
||||||
session.commit()
|
|
||||||
return model
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(
|
|
||||||
f"Error updating database for {model_relative_path}: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def process_file(self, model_path, source_url=None, model_hash=None):
|
|
||||||
"""
|
|
||||||
Process a model file and update the database with metadata.
|
|
||||||
If the file already exists and matches the database, it will not be processed again.
|
|
||||||
Returns the model object or if an error occurs, returns None.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not can_create_session():
|
|
||||||
return
|
|
||||||
|
|
||||||
result = self._validate_path(model_path)
|
|
||||||
if not result:
|
|
||||||
return
|
|
||||||
model_type, model_relative_path = result
|
|
||||||
|
|
||||||
with create_session() as session:
|
|
||||||
session.expire_on_commit = False
|
|
||||||
|
|
||||||
existing_model = self._get_existing_model(
|
|
||||||
session, model_type, model_relative_path
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
existing_model
|
|
||||||
and existing_model.hash
|
|
||||||
and existing_model.file_size == self._get_file_size(model_path)
|
|
||||||
):
|
|
||||||
# File exists with hash and same size, no need to process
|
|
||||||
self._ensure_source_url(session, existing_model, source_url)
|
|
||||||
return existing_model
|
|
||||||
|
|
||||||
if model_hash:
|
|
||||||
model_hash = model_hash.lower()
|
|
||||||
logging.info(f"Using provided hash: {model_hash}")
|
|
||||||
else:
|
|
||||||
start_time = time.time()
|
|
||||||
logging.info(f"Hashing model {model_relative_path}")
|
|
||||||
model_hash = self._hash_file(model_path)
|
|
||||||
if not model_hash:
|
|
||||||
return
|
|
||||||
logging.info(
|
|
||||||
f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._update_database(
|
|
||||||
session,
|
|
||||||
model_type,
|
|
||||||
model_path,
|
|
||||||
model_relative_path,
|
|
||||||
model_hash,
|
|
||||||
existing_model,
|
|
||||||
source_url,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error processing model file {model_path}: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def retrieve_model_by_hash(self, model_hash, model_type=None, session=None):
|
|
||||||
"""
|
|
||||||
Retrieve a model file from the database by hash and optionally by model type.
|
|
||||||
Returns the model object or None if the model doesn't exist or an error occurs.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not can_create_session():
|
|
||||||
return
|
|
||||||
|
|
||||||
dispose_session = False
|
|
||||||
|
|
||||||
if session is None:
|
|
||||||
session = create_session()
|
|
||||||
dispose_session = True
|
|
||||||
|
|
||||||
model = session.query(Model).filter(Model.hash == model_hash)
|
|
||||||
if model_type is not None:
|
|
||||||
model = model.filter(Model.type == model_type)
|
|
||||||
return model.first()
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}")
|
|
||||||
return None
|
|
||||||
finally:
|
|
||||||
if dispose_session:
|
|
||||||
session.close()
|
|
||||||
|
|
||||||
def retrieve_hash(self, model_path, model_type=None):
|
|
||||||
"""
|
|
||||||
Retrieve the hash of a model file from the database.
|
|
||||||
Returns the hash or None if the model doesn't exist or an error occurs.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not can_create_session():
|
|
||||||
return
|
|
||||||
|
|
||||||
if model_type is not None:
|
|
||||||
result = self._validate_path(model_path)
|
|
||||||
if not result:
|
|
||||||
return None
|
|
||||||
model_type, model_relative_path = result
|
|
||||||
|
|
||||||
with create_session() as session:
|
|
||||||
model = self._get_existing_model(
|
|
||||||
session, model_type, model_relative_path
|
|
||||||
)
|
|
||||||
if model and model.hash:
|
|
||||||
return model.hash
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error retrieving hash for {model_path}: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _validate_file_extension(self, file_name):
|
|
||||||
"""Validate that the file extension is supported."""
|
|
||||||
extension = os.path.splitext(file_name)[1]
|
|
||||||
if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"):
|
|
||||||
raise ValueError(f"Unsupported unsafe file for download: {file_name}")
|
|
||||||
|
|
||||||
def _check_existing_file(self, model_type, file_name, expected_hash):
|
|
||||||
"""Check if file exists and has correct hash."""
|
|
||||||
destination_path = get_full_path(model_type, file_name, allow_missing=True)
|
|
||||||
if self._file_exists(destination_path):
|
|
||||||
model = self.process_file(destination_path)
|
|
||||||
if model and (expected_hash is None or model.hash == expected_hash):
|
|
||||||
logging.debug(
|
|
||||||
f"File {destination_path} already exists in the database and has the correct hash or no hash was provided."
|
|
||||||
)
|
|
||||||
return destination_path
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again."
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _check_existing_file_by_hash(self, hash, type, url):
|
|
||||||
"""Check if a file with the given hash exists in the database and on disk."""
|
|
||||||
hash = hash.lower()
|
|
||||||
with create_session() as session:
|
|
||||||
model = self.retrieve_model_by_hash(hash, type, session)
|
|
||||||
if model:
|
|
||||||
existing_path = get_full_path(type, model.path)
|
|
||||||
if existing_path:
|
|
||||||
logging.debug(
|
|
||||||
f"File {model.path} already exists in the database at {existing_path}"
|
|
||||||
)
|
|
||||||
self._ensure_source_url(session, model, url)
|
|
||||||
return existing_path
|
|
||||||
else:
|
|
||||||
logging.debug(
|
|
||||||
f"File {model.path} exists in the database but not on disk"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _download_file(self, url, destination_path, hasher):
|
|
||||||
"""Download a file and update the hasher with its contents."""
|
|
||||||
response = requests.get(url, stream=True)
|
|
||||||
logging.info(f"Downloading {url} to {destination_path}")
|
|
||||||
|
|
||||||
with open(destination_path, "wb") as f:
|
|
||||||
total_size = int(response.headers.get("content-length", 0))
|
|
||||||
if total_size > 0:
|
|
||||||
pbar = comfy.utils.ProgressBar(total_size)
|
|
||||||
else:
|
|
||||||
pbar = None
|
|
||||||
with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
|
|
||||||
for chunk in response.iter_content(chunk_size=128 * 1024):
|
|
||||||
if chunk:
|
|
||||||
f.write(chunk)
|
|
||||||
hasher.update(chunk)
|
|
||||||
progress_bar.update(len(chunk))
|
|
||||||
if pbar:
|
|
||||||
pbar.update(len(chunk))
|
|
||||||
|
|
||||||
def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path):
|
|
||||||
"""Verify that the downloaded file has the expected hash."""
|
|
||||||
if expected_hash is not None and calculated_hash != expected_hash:
|
|
||||||
self._remove_file(destination_path)
|
|
||||||
raise ValueError(
|
|
||||||
f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _remove_file(self, file_path):
|
|
||||||
"""Remove a file from disk."""
|
|
||||||
os.remove(file_path)
|
|
||||||
|
|
||||||
def ensure_downloaded(self, type, url, desired_file_name, hash=None):
|
|
||||||
"""
|
|
||||||
Ensure a model file is downloaded and has the correct hash.
|
|
||||||
Returns the path to the downloaded file.
|
|
||||||
"""
|
|
||||||
logging.debug(
|
|
||||||
f"Ensuring {type} file is downloaded. URL='{url}' Destination='{desired_file_name}' Hash='{hash}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate file extension
|
|
||||||
self._validate_file_extension(desired_file_name)
|
|
||||||
|
|
||||||
# Check if file exists with correct hash
|
|
||||||
if hash:
|
|
||||||
existing_path = self._check_existing_file_by_hash(hash, type, url)
|
|
||||||
if existing_path:
|
|
||||||
return existing_path
|
|
||||||
|
|
||||||
# Check if file exists locally
|
|
||||||
destination_path = get_full_path(type, desired_file_name, allow_missing=True)
|
|
||||||
existing_path = self._check_existing_file(type, desired_file_name, hash)
|
|
||||||
if existing_path:
|
|
||||||
return existing_path
|
|
||||||
|
|
||||||
# Download the file
|
|
||||||
hasher = self._get_hasher()
|
|
||||||
self._download_file(url, destination_path, hasher)
|
|
||||||
|
|
||||||
# Verify hash
|
|
||||||
calculated_hash = hasher.hexdigest()
|
|
||||||
self._verify_downloaded_hash(calculated_hash, hash, destination_path)
|
|
||||||
|
|
||||||
# Update database
|
|
||||||
self.process_file(destination_path, url, calculated_hash)
|
|
||||||
|
|
||||||
# TODO: Notify frontend to reload models
|
|
||||||
|
|
||||||
return destination_path
|
|
||||||
|
|
||||||
|
|
||||||
model_processor = ModelProcessor()
|
|
||||||
0
app/storage/__init__.py
Normal file
72
app/storage/hashing.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from typing import IO, Union
|
||||||
|
|
||||||
|
from blake3 import blake3
|
||||||
|
|
||||||
|
DEFAULT_CHUNK = 8 * 1024 * 1024 # 8 MiB
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_file_obj_sync(file_obj: IO[bytes], chunk_size: int) -> str:
|
||||||
|
"""Hash an already-open binary file object by streaming in chunks.
|
||||||
|
- Seeks to the beginning before reading (if supported).
|
||||||
|
- Restores the original position afterward (if tell/seek are supported).
|
||||||
|
"""
|
||||||
|
if chunk_size <= 0:
|
||||||
|
chunk_size = DEFAULT_CHUNK
|
||||||
|
|
||||||
|
orig_pos = None
|
||||||
|
if hasattr(file_obj, "tell"):
|
||||||
|
orig_pos = file_obj.tell()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if hasattr(file_obj, "seek"):
|
||||||
|
file_obj.seek(0)
|
||||||
|
|
||||||
|
h = blake3()
|
||||||
|
while True:
|
||||||
|
chunk = file_obj.read(chunk_size)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
h.update(chunk)
|
||||||
|
return h.hexdigest()
|
||||||
|
finally:
|
||||||
|
if hasattr(file_obj, "seek") and orig_pos is not None:
|
||||||
|
file_obj.seek(orig_pos)
|
||||||
|
|
||||||
|
|
||||||
|
def blake3_hash_sync(
|
||||||
|
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
|
||||||
|
chunk_size: int = DEFAULT_CHUNK,
|
||||||
|
) -> str:
|
||||||
|
"""Returns a BLAKE3 hex digest for ``fp``, which may be:
|
||||||
|
- a filename (str/bytes) or PathLike
|
||||||
|
- an open binary file object
|
||||||
|
|
||||||
|
If ``fp`` is a file object, it must be opened in **binary** mode and support
|
||||||
|
``read``, ``seek``, and ``tell``. The function will seek to the start before
|
||||||
|
reading and will attempt to restore the original position afterward.
|
||||||
|
"""
|
||||||
|
if hasattr(fp, "read"):
|
||||||
|
return _hash_file_obj_sync(fp, chunk_size)
|
||||||
|
|
||||||
|
with open(os.fspath(fp), "rb") as f:
|
||||||
|
return _hash_file_obj_sync(f, chunk_size)
|
||||||
|
|
||||||
|
|
||||||
|
async def blake3_hash(
|
||||||
|
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
|
||||||
|
chunk_size: int = DEFAULT_CHUNK,
|
||||||
|
) -> str:
|
||||||
|
"""Async wrapper for ``blake3_hash_sync``.
|
||||||
|
Uses a worker thread so the event loop remains responsive.
|
||||||
|
"""
|
||||||
|
# If it is a path, open inside the worker thread to keep I/O off the loop.
|
||||||
|
if hasattr(fp, "read"):
|
||||||
|
return await asyncio.to_thread(blake3_hash_sync, fp, chunk_size)
|
||||||
|
|
||||||
|
def _worker() -> str:
|
||||||
|
with open(os.fspath(fp), "rb") as f:
|
||||||
|
return _hash_file_obj_sync(f, chunk_size)
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_worker)
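A minimal usage sketch for the two hashing entry points above (illustrative only; "model.safetensors" is an example path, and any binary file object also works):
import asyncio
import io

async def _demo() -> None:
    print(await blake3_hash("model.safetensors"))
    print(await blake3_hash(io.BytesIO(b"hello")))
    print(blake3_hash_sync(io.BytesIO(b"hello")))  # sync variant, same digest

if __name__ == "__main__":
    asyncio.run(_demo())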
|
||||||
@ -1,110 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, TypedDict
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
import comfy.utils
|
|
||||||
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
class AssetMetadata(TypedDict):
|
|
||||||
device: Any
|
|
||||||
return_metadata: bool
|
|
||||||
subdir: str
|
|
||||||
download_url: str
|
|
||||||
|
|
||||||
class AssetInfo:
|
|
||||||
def __init__(self, hash: str=None, name: str=None, tags: list[str]=None, metadata: AssetMetadata={}):
|
|
||||||
self.hash = hash
|
|
||||||
self.name = name
|
|
||||||
self.tags = tags
|
|
||||||
self.metadata = metadata
|
|
||||||
|
|
||||||
|
|
||||||
class ReturnedAssetABC(ABC):
|
|
||||||
def __init__(self, mimetype: str):
|
|
||||||
self.mimetype = mimetype
|
|
||||||
|
|
||||||
|
|
||||||
class ModelReturnedAsset(ReturnedAssetABC):
|
|
||||||
def __init__(self, state_dict: dict[str, str], metadata: dict[str, str]=None):
|
|
||||||
super().__init__("model")
|
|
||||||
self.state_dict = state_dict
|
|
||||||
self.metadata = metadata
|
|
||||||
|
|
||||||
|
|
||||||
class AssetResolverABC(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def resolve(self, asset_info: AssetInfo) -> ReturnedAssetABC:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAssetResolver(AssetResolverABC):
|
|
||||||
def resolve(self, asset_info: AssetInfo, cache_result: bool=True) -> ReturnedAssetABC:
|
|
||||||
# currently only supports models - make sure models is in the tags
|
|
||||||
if "models" not in asset_info.tags:
|
|
||||||
return None
|
|
||||||
# TODO: if hash exists, call model processor to try to get info about model:
|
|
||||||
if asset_info.hash:
|
|
||||||
try:
|
|
||||||
from app.model_processor import model_processor
|
|
||||||
model_db = model_processor.retrieve_model_by_hash(asset_info.hash)
|
|
||||||
full_path = model_db.path
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Could not get model by hash with error: {e}")
|
|
||||||
# the good ol' bread and butter - folder_paths's keys as tags
|
|
||||||
folder_keys = folder_paths.folder_names_and_paths.keys()
|
|
||||||
parent_paths = []
|
|
||||||
for tag in asset_info.tags:
|
|
||||||
if tag in folder_keys:
|
|
||||||
parent_paths.append(tag)
|
|
||||||
# if subdir metadata and name exists, use that as the model name going forward
|
|
||||||
if "subdir" in asset_info.metadata and asset_info.name:
|
|
||||||
# if no matching parent paths, then something went wrong and should return None
|
|
||||||
if len(parent_paths) == 0:
|
|
||||||
return None
|
|
||||||
relative_path = os.path.join(asset_info.metadata["subdir"], asset_info.name)
|
|
||||||
# now we have the parent keys, we can try to get the local path
|
|
||||||
chosen_parent = None
|
|
||||||
full_path = None
|
|
||||||
for parent_path in parent_paths:
|
|
||||||
full_path = folder_paths.get_full_path(parent_path, relative_path)
|
|
||||||
if full_path:
|
|
||||||
chosen_parent = parent_path
|
|
||||||
break
|
|
||||||
if full_path is not None:
|
|
||||||
logging.info(f"Resolved {asset_info.name} to {full_path} in {chosen_parent}")
|
|
||||||
# we know the path, so load the model and return it
|
|
||||||
state_dict, metadata = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True)
|
|
||||||
# TODO: handle caching
|
|
||||||
return ModelReturnedAsset(state_dict, metadata)
|
|
||||||
# if just name exists, try to find model by name in all subdirs of parent paths
|
|
||||||
# TODO: this behavior should be configurable by user
|
|
||||||
if asset_info.name:
|
|
||||||
for parent_path in parent_paths:
|
|
||||||
filelist = folder_paths.get_filename_list(parent_path)
|
|
||||||
for file in filelist:
|
|
||||||
if os.path.basename(file) == asset_info.name:
|
|
||||||
full_path = folder_paths.get_full_path(parent_path, file)
|
|
||||||
state_dict, metadata = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True)
|
|
||||||
# TODO: handle caching
|
|
||||||
return ModelReturnedAsset(state_dict, metadata)
|
|
||||||
# TODO: if download_url metadata exists, download the model and load it; this should be configurable by user
|
|
||||||
if asset_info.metadata.get("download_url", None):
|
|
||||||
...
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
resolvers: list[AssetResolverABC] = []
|
|
||||||
resolvers.append(LocalAssetResolver())
|
|
||||||
|
|
||||||
|
|
||||||
def resolve(asset_info: AssetInfo) -> Any:
|
|
||||||
global resolvers
|
|
||||||
for resolver in resolvers:
|
|
||||||
try:
|
|
||||||
to_return = resolver.resolve(asset_info)
|
|
||||||
if to_return is not None:
|
|
||||||
return resolver.resolve(asset_info)
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error resolving asset {asset_info.name} using {resolver.__class__.__name__}: {e}")
|
|
||||||
return None
|
|
||||||
@ -211,7 +211,7 @@ parser.add_argument(
|
|||||||
database_default_path = os.path.abspath(
|
database_default_path = os.path.abspath(
|
||||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
||||||
)
|
)
|
||||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.")
|
||||||
parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
|
parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
|
||||||
|
|
||||||
if comfy.options.args_parsing:
|
if comfy.options.args_parsing:
|
||||||
|
|||||||
@ -102,12 +102,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
|||||||
else:
|
else:
|
||||||
sd = pl_sd
|
sd = pl_sd
|
||||||
|
|
||||||
try:
|
|
||||||
from app.model_processor import model_processor
|
|
||||||
model_processor.process_file(ckpt)
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error processing file {ckpt}: {e}")
|
|
||||||
|
|
||||||
return (sd, metadata) if return_metadata else sd
|
return (sd, metadata) if return_metadata else sd
|
||||||
|
|
||||||
def save_torch_file(sd, ckpt, metadata=None):
|
def save_torch_file(sd, ckpt, metadata=None):
|
||||||
|
|||||||
@ -1,56 +0,0 @@
|
|||||||
from comfy_api.latest import io, ComfyExtension
|
|
||||||
import comfy.asset_management
|
|
||||||
import comfy.sd
|
|
||||||
import folder_paths
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
class AssetTestNode(io.ComfyNode):
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls):
|
|
||||||
return io.Schema(
|
|
||||||
node_id="AssetTestNode",
|
|
||||||
is_experimental=True,
|
|
||||||
inputs=[
|
|
||||||
io.Combo.Input("ckpt_name", folder_paths.get_filename_list("checkpoints")),
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
io.Model.Output(),
|
|
||||||
io.Clip.Output(),
|
|
||||||
io.Vae.Output(),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def execute(cls, ckpt_name: str):
|
|
||||||
hash = None
|
|
||||||
# lets get the full path just so we can retrieve the hash from db, if exists
|
|
||||||
try:
|
|
||||||
full_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
|
||||||
if full_path is None:
|
|
||||||
raise Exception(f"Model {ckpt_name} not found")
|
|
||||||
from app.model_processor import model_processor
|
|
||||||
hash = model_processor.retrieve_hash(full_path)
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Could not get model by hash with error: {e}")
|
|
||||||
subdir, name = os.path.split(ckpt_name)
|
|
||||||
asset_info = comfy.asset_management.AssetInfo(hash=hash, name=name, tags=["models", "checkpoints"], metadata={"subdir": subdir})
|
|
||||||
asset = comfy.asset_management.resolve(asset_info)
|
|
||||||
# /\ the stuff above should happen in execution code instead of inside the node
|
|
||||||
# \/ the stuff below should happen in the node - confirm is a model asset, do stuff to it (already loaded? or should be called to 'load'?)
|
|
||||||
if asset is None:
|
|
||||||
raise Exception(f"Model {asset_info.name} not found")
|
|
||||||
assert isinstance(asset, comfy.asset_management.ModelReturnedAsset)
|
|
||||||
out = comfy.sd.load_state_dict_guess_config(asset.state_dict, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), metadata=asset.metadata)
|
|
||||||
return io.NodeOutput(out[0], out[1], out[2])
|
|
||||||
|
|
||||||
|
|
||||||
class AssetTestExtension(ComfyExtension):
|
|
||||||
@classmethod
|
|
||||||
async def get_node_list(cls):
|
|
||||||
return [AssetTestNode]
|
|
||||||
|
|
||||||
|
|
||||||
def comfy_entrypoint():
|
|
||||||
return AssetTestExtension()
|
|
||||||
9
main.py
@ -278,11 +278,11 @@ def cleanup_temp():
|
|||||||
if os.path.exists(temp_dir):
|
if os.path.exists(temp_dir):
|
||||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
def setup_database():
|
async def setup_database():
|
||||||
try:
|
try:
|
||||||
from app.database.db import init_db, dependencies_available
|
from app.database.db import init_db_engine, dependencies_available
|
||||||
if dependencies_available():
|
if dependencies_available():
|
||||||
init_db()
|
await init_db_engine()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
||||||
|
|
||||||
@ -309,6 +309,8 @@ def start_comfyui(asyncio_loop=None):
|
|||||||
asyncio.set_event_loop(asyncio_loop)
|
asyncio.set_event_loop(asyncio_loop)
|
||||||
prompt_server = server.PromptServer(asyncio_loop)
|
prompt_server = server.PromptServer(asyncio_loop)
|
||||||
|
|
||||||
|
asyncio_loop.run_until_complete(setup_database())
|
||||||
|
|
||||||
hook_breaker_ac10a0.save_functions()
|
hook_breaker_ac10a0.save_functions()
|
||||||
asyncio_loop.run_until_complete(nodes.init_extra_nodes(
|
asyncio_loop.run_until_complete(nodes.init_extra_nodes(
|
||||||
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
||||||
@ -317,7 +319,6 @@ def start_comfyui(asyncio_loop=None):
|
|||||||
hook_breaker_ac10a0.restore_functions()
|
hook_breaker_ac10a0.restore_functions()
|
||||||
|
|
||||||
cuda_malloc_warning()
|
cuda_malloc_warning()
|
||||||
setup_database()
|
|
||||||
|
|
||||||
prompt_server.add_routes()
|
prompt_server.add_routes()
|
||||||
hijack_progress(prompt_server)
|
hijack_progress(prompt_server)
|
||||||
|
|||||||
7
nodes.py
@ -28,9 +28,10 @@ import comfy.sd
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.controlnet
|
import comfy.controlnet
|
||||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
|
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
|
||||||
from comfy_api.internal import register_versions, ComfyAPIWithVersion
|
from comfy_api.internal import async_to_sync, register_versions, ComfyAPIWithVersion
|
||||||
from comfy_api.version_list import supported_versions
|
from comfy_api.version_list import supported_versions
|
||||||
from comfy_api.latest import io, ComfyExtension
|
from comfy_api.latest import io, ComfyExtension
|
||||||
|
from app.assets_manager import add_local_asset
|
||||||
|
|
||||||
import comfy.clip_vision
|
import comfy.clip_vision
|
||||||
|
|
||||||
@ -777,6 +778,9 @@ class VAELoader:
|
|||||||
else:
|
else:
|
||||||
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||||
sd = comfy.utils.load_torch_file(vae_path)
|
sd = comfy.utils.load_torch_file(vae_path)
|
||||||
|
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
|
||||||
|
add_local_asset, tags=["models", "vae"], file_name=vae_name, file_path=vae_path
|
||||||
|
)
|
||||||
vae = comfy.sd.VAE(sd=sd)
|
vae = comfy.sd.VAE(sd=sd)
|
||||||
vae.throw_exception_if_invalid()
|
vae.throw_exception_if_invalid()
|
||||||
return (vae,)
|
return (vae,)
|
||||||
@ -2321,7 +2325,6 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_edit_model.py",
|
"nodes_edit_model.py",
|
||||||
"nodes_tcfg.py",
|
"nodes_tcfg.py",
|
||||||
"nodes_context_windows.py",
|
"nodes_context_windows.py",
|
||||||
"nodes_assets_test.py",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
@ -20,6 +20,7 @@ tqdm
|
|||||||
psutil
|
psutil
|
||||||
alembic
|
alembic
|
||||||
SQLAlchemy
|
SQLAlchemy
|
||||||
|
aiosqlite
|
||||||
av>=14.2.0
|
av>=14.2.0
|
||||||
blake3
|
blake3
|
||||||
|
|
||||||
|
|||||||
@ -37,6 +37,7 @@ from app.model_manager import ModelFileManager
|
|||||||
from app.custom_node_manager import CustomNodeManager
|
from app.custom_node_manager import CustomNodeManager
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||||
|
from app.api.assets_routes import register_assets_routes
|
||||||
from protocol import BinaryEventTypes
|
from protocol import BinaryEventTypes
|
||||||
|
|
||||||
async def send_socket_catch_exception(function, message):
|
async def send_socket_catch_exception(function, message):
|
||||||
@ -183,6 +184,7 @@ class PromptServer():
|
|||||||
else args.front_end_root
|
else args.front_end_root
|
||||||
)
|
)
|
||||||
logging.info(f"[Prompt Server] web root: {self.web_root}")
|
logging.info(f"[Prompt Server] web root: {self.web_root}")
|
||||||
|
register_assets_routes(self.app)
|
||||||
routes = web.RouteTableDef()
|
routes = web.RouteTableDef()
|
||||||
self.routes = routes
|
self.routes = routes
|
||||||
self.last_node_id = None
|
self.last_node_id = None
|
||||||
|
|||||||
@ -1,62 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import base64
|
|
||||||
import json
|
|
||||||
import struct
|
|
||||||
from io import BytesIO
|
|
||||||
from PIL import Image
|
|
||||||
from aiohttp import web
|
|
||||||
from unittest.mock import patch
|
|
||||||
from app.model_manager import ModelFileManager
|
|
||||||
|
|
||||||
pytestmark = (
|
|
||||||
pytest.mark.asyncio
|
|
||||||
) # This applies the asyncio mark to all test functions in the module
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def model_manager():
|
|
||||||
return ModelFileManager()
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def app(model_manager):
|
|
||||||
app = web.Application()
|
|
||||||
routes = web.RouteTableDef()
|
|
||||||
model_manager.add_routes(routes)
|
|
||||||
app.add_routes(routes)
|
|
||||||
return app
|
|
||||||
|
|
||||||
async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
|
|
||||||
img = Image.new('RGB', (100, 100), 'white')
|
|
||||||
img_byte_arr = BytesIO()
|
|
||||||
img.save(img_byte_arr, format='PNG')
|
|
||||||
img_byte_arr.seek(0)
|
|
||||||
img_b64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
|
|
||||||
|
|
||||||
safetensors_file = tmp_path / "test_model.safetensors"
|
|
||||||
header_bytes = json.dumps({
|
|
||||||
"__metadata__": {
|
|
||||||
"ssmd_cover_images": json.dumps([img_b64])
|
|
||||||
}
|
|
||||||
}).encode('utf-8')
|
|
||||||
length_bytes = struct.pack('<Q', len(header_bytes))
|
|
||||||
with open(safetensors_file, 'wb') as f:
|
|
||||||
f.write(length_bytes)
|
|
||||||
f.write(header_bytes)
|
|
||||||
|
|
||||||
with patch('folder_paths.folder_names_and_paths', {
|
|
||||||
'test_folder': ([str(tmp_path)], None)
|
|
||||||
}):
|
|
||||||
client = await aiohttp_client(app)
|
|
||||||
response = await client.get('/experiment/models/preview/test_folder/0/test_model.safetensors')
|
|
||||||
|
|
||||||
# Verify response
|
|
||||||
assert response.status == 200
|
|
||||||
assert response.content_type == 'image/webp'
|
|
||||||
|
|
||||||
# Verify the response contains valid image data
|
|
||||||
img_bytes = BytesIO(await response.read())
|
|
||||||
img = Image.open(img_bytes)
|
|
||||||
assert img.format
|
|
||||||
assert img.format.lower() == 'webp'
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
img.close()
|
|
||||||
@ -1,253 +0,0 @@
|
|||||||
import pytest
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
from sqlalchemy import create_engine
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
from app.model_processor import ModelProcessor
|
|
||||||
from app.database.models import Model, Base
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Test data constants
|
|
||||||
TEST_MODEL_TYPE = "checkpoints"
|
|
||||||
TEST_URL = "http://example.com/model.safetensors"
|
|
||||||
TEST_FILE_NAME = "model.safetensors"
|
|
||||||
TEST_EXPECTED_HASH = "abc123"
|
|
||||||
TEST_DESTINATION_PATH = "/path/to/model.safetensors"
|
|
||||||
|
|
||||||
|
|
||||||
def create_test_model(session, file_name, model_type, hash_value, file_size=1000, source_url=None):
|
|
||||||
"""Helper to create a test model in the database."""
|
|
||||||
model = Model(path=file_name, type=model_type, hash=hash_value, file_size=file_size, source_url=source_url)
|
|
||||||
session.add(model)
|
|
||||||
session.commit()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def setup_mock_hash_calculation(model_processor, hash_value):
|
|
||||||
"""Helper to setup hash calculation mocks."""
|
|
||||||
mock_hash = MagicMock()
|
|
||||||
mock_hash.hexdigest.return_value = hash_value
|
|
||||||
return patch.object(model_processor, "_get_hasher", return_value=mock_hash)
|
|
||||||
|
|
||||||
|
|
||||||
def verify_model_in_db(session, file_name, expected_hash=None, expected_type=None):
|
|
||||||
"""Helper to verify model exists in database with correct attributes."""
|
|
||||||
db_model = session.query(Model).filter_by(path=file_name).first()
|
|
||||||
assert db_model is not None
|
|
||||||
if expected_hash:
|
|
||||||
assert db_model.hash == expected_hash
|
|
||||||
if expected_type:
|
|
||||||
assert db_model.type == expected_type
|
|
||||||
return db_model
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def db_engine():
|
|
||||||
# Configure in-memory database
|
|
||||||
engine = create_engine("sqlite:///:memory:")
|
|
||||||
Base.metadata.create_all(engine)
|
|
||||||
yield engine
|
|
||||||
Base.metadata.drop_all(engine)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def db_session(db_engine):
|
|
||||||
Session = sessionmaker(bind=db_engine)
|
|
||||||
session = Session()
|
|
||||||
yield session
|
|
||||||
session.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_get_relative_path():
|
|
||||||
with patch("app.model_processor.get_relative_path") as mock:
|
|
||||||
mock.side_effect = lambda path: (TEST_MODEL_TYPE, os.path.basename(path))
|
|
||||||
yield mock
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_get_full_path():
|
|
||||||
with patch("app.model_processor.get_full_path") as mock:
|
|
||||||
mock.return_value = TEST_DESTINATION_PATH
|
|
||||||
yield mock
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def model_processor(db_session, mock_get_relative_path, mock_get_full_path):
|
|
||||||
with patch("app.model_processor.create_session", return_value=db_session):
|
|
||||||
with patch("app.model_processor.can_create_session", return_value=True):
|
|
||||||
processor = ModelProcessor()
|
|
||||||
# Setup test state
|
|
||||||
processor.removed_files = []
|
|
||||||
processor.downloaded_files = []
|
|
||||||
processor.file_exists = {}
|
|
||||||
|
|
||||||
def mock_download_file(url, destination_path, hasher):
|
|
||||||
processor.downloaded_files.append((url, destination_path))
|
|
||||||
processor.file_exists[destination_path] = True
|
|
||||||
# Simulate writing some data to the file
|
|
||||||
test_data = b"test data"
|
|
||||||
hasher.update(test_data)
|
|
||||||
|
|
||||||
def mock_remove_file(file_path):
|
|
||||||
processor.removed_files.append(file_path)
|
|
||||||
if file_path in processor.file_exists:
|
|
||||||
del processor.file_exists[file_path]
|
|
||||||
|
|
||||||
# Setup common patches
|
|
||||||
file_exists_patch = patch.object(
|
|
||||||
processor,
|
|
||||||
"_file_exists",
|
|
||||||
side_effect=lambda path: processor.file_exists.get(path, False),
|
|
||||||
)
|
|
||||||
file_size_patch = patch.object(
|
|
||||||
processor,
|
|
||||||
"_get_file_size",
|
|
||||||
side_effect=lambda path: (
|
|
||||||
1000 if processor.file_exists.get(path, False) else 0
|
|
||||||
),
|
|
||||||
)
|
|
||||||
download_file_patch = patch.object(
|
|
||||||
processor, "_download_file", side_effect=mock_download_file
|
|
||||||
)
|
|
||||||
remove_file_patch = patch.object(
|
|
||||||
processor, "_remove_file", side_effect=mock_remove_file
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
file_exists_patch,
|
|
||||||
file_size_patch,
|
|
||||||
download_file_patch,
|
|
||||||
remove_file_patch,
|
|
||||||
):
|
|
||||||
yield processor
|
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_downloaded_invalid_extension(model_processor):
|
|
||||||
# Ensure that an unsupported file extension raises an error to prevent unsafe file downloads
|
|
||||||
with pytest.raises(ValueError, match="Unsupported unsafe file for download"):
|
|
||||||
model_processor.ensure_downloaded(TEST_MODEL_TYPE, TEST_URL, "model.exe")
|
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_downloaded_existing_file_with_hash(model_processor, db_session):
|
|
||||||
# Ensure that a file with the same hash but from a different source is not downloaded again
|
|
||||||
SOURCE_URL = "https://example.com/other.sft"
|
|
||||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH, source_url=SOURCE_URL)
|
|
||||||
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
|
||||||
|
|
||||||
result = model_processor.ensure_downloaded(
|
|
||||||
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == TEST_DESTINATION_PATH
|
|
||||||
model = verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
|
|
||||||
assert model.source_url == SOURCE_URL # Ensure the source URL is not overwritten
|
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_downloaded_existing_file_hash_mismatch(model_processor, db_session):
|
|
||||||
# Ensure that a file with a different hash raises an error
|
|
||||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, "different_hash")
|
|
||||||
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="File .* exists with hash .* but expected .*"):
|
|
||||||
model_processor.ensure_downloaded(
|
|
||||||
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_downloaded_new_file(model_processor, db_session):
|
|
||||||
# Ensure that a new file is downloaded
|
|
||||||
model_processor.file_exists[TEST_DESTINATION_PATH] = False
|
|
||||||
|
|
||||||
with setup_mock_hash_calculation(model_processor, TEST_EXPECTED_HASH):
|
|
||||||
result = model_processor.ensure_downloaded(
|
|
||||||
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == TEST_DESTINATION_PATH
|
|
||||||
assert len(model_processor.downloaded_files) == 1
|
|
||||||
assert model_processor.downloaded_files[0] == (TEST_URL, TEST_DESTINATION_PATH)
|
|
||||||
assert model_processor.file_exists[TEST_DESTINATION_PATH]
|
|
||||||
verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
|
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_downloaded_hash_mismatch(model_processor, db_session):
|
|
||||||
# Ensure that download that results in a different hash raises an error
|
|
||||||
model_processor.file_exists[TEST_DESTINATION_PATH] = False
|
|
||||||
|
|
||||||
with setup_mock_hash_calculation(model_processor, "different_hash"):
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match="Downloaded file hash .* does not match expected hash .*",
|
|
||||||
):
|
|
||||||
model_processor.ensure_downloaded(
|
|
||||||
TEST_MODEL_TYPE,
|
|
||||||
TEST_URL,
|
|
||||||
TEST_FILE_NAME,
|
|
||||||
TEST_EXPECTED_HASH,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(model_processor.removed_files) == 1
|
|
||||||
assert model_processor.removed_files[0] == TEST_DESTINATION_PATH
|
|
||||||
assert TEST_DESTINATION_PATH not in model_processor.file_exists
|
|
||||||
assert db_session.query(Model).filter_by(path=TEST_FILE_NAME).first() is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_file_without_hash(model_processor, db_session):
|
|
||||||
# Test processing file without provided hash
|
|
||||||
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
|
||||||
|
|
||||||
with patch.object(model_processor, "_hash_file", return_value=TEST_EXPECTED_HASH):
|
|
||||||
result = model_processor.process_file(TEST_DESTINATION_PATH)
|
|
||||||
assert result is not None
|
|
||||||
assert result.hash == TEST_EXPECTED_HASH
|
|
||||||
|
|
||||||
|
|
||||||
def test_retrieve_model_by_hash(model_processor, db_session):
|
|
||||||
# Test retrieving model by hash
|
|
||||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
|
||||||
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH)
|
|
||||||
assert result is not None
|
|
||||||
assert result.hash == TEST_EXPECTED_HASH
|
|
||||||
|
|
||||||
|
|
||||||
def test_retrieve_model_by_hash_and_type(model_processor, db_session):
|
|
||||||
# Test retrieving model by hash and type
|
|
||||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
|
||||||
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
|
|
||||||
assert result is not None
|
|
||||||
assert result.hash == TEST_EXPECTED_HASH
|
|
||||||
assert result.type == TEST_MODEL_TYPE
|
|
||||||
|
|
||||||
|
|
||||||
def test_retrieve_hash(model_processor, db_session):
|
|
||||||
# Test retrieving hash for existing model
|
|
||||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
|
||||||
with patch.object(
|
|
||||||
model_processor,
|
|
||||||
"_validate_path",
|
|
||||||
return_value=(TEST_MODEL_TYPE, TEST_FILE_NAME),
|
|
||||||
):
|
|
||||||
result = model_processor.retrieve_hash(TEST_DESTINATION_PATH, TEST_MODEL_TYPE)
|
|
||||||
assert result == TEST_EXPECTED_HASH
|
|
||||||
|
|
||||||
|
|
||||||
def test_validate_file_extension_valid_extensions(model_processor):
|
|
||||||
# Test all valid file extensions
|
|
||||||
valid_extensions = [".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"]
|
|
||||||
for ext in valid_extensions:
|
|
||||||
model_processor._validate_file_extension(f"test{ext}") # Should not raise
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_file_existing_without_source_url(model_processor, db_session):
|
|
||||||
# Test processing an existing file that needs its source URL updated
|
|
||||||
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
|
||||||
|
|
||||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
|
||||||
result = model_processor.process_file(TEST_DESTINATION_PATH, source_url=TEST_URL)
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
assert result.hash == TEST_EXPECTED_HASH
|
|
||||||
assert result.source_url == TEST_URL
|
|
||||||
|
|
||||||
db_model = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first()
|
|
||||||
assert db_model.source_url == TEST_URL
|
|
||||||