mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 17:42:58 +08:00
Merge pull request #9545 from bigcat88/asset-management
[Assets] Initial implementation
This commit is contained in:
commit
46fdd636de
174
.github/workflows/test-assets.yml
vendored
Normal file
174
.github/workflows/test-assets.yml
vendored
Normal file
@ -0,0 +1,174 @@
|
||||
name: Asset System Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- 'app/**'
|
||||
- 'alembic_db/**'
|
||||
- 'tests-assets/**'
|
||||
- '.github/workflows/test-assets.yml'
|
||||
- 'requirements.txt'
|
||||
pull_request:
|
||||
branches: [master]
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
PIP_DISABLE_PIP_VERSION_CHECK: '1'
|
||||
PYTHONUNBUFFERED: '1'
|
||||
|
||||
jobs:
|
||||
sqlite:
|
||||
name: SQLite (${{ matrix.sqlite_mode }}) • Python ${{ matrix.python }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 40
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python: ['3.9', '3.12']
|
||||
sqlite_mode: ['memory', 'file']
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install -U pip wheel
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install pytest pytest-aiohttp pytest-asyncio
|
||||
|
||||
- name: Set deterministic test base dir
|
||||
id: basedir
|
||||
shell: bash
|
||||
run: |
|
||||
BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.sqlite_mode }}-${{ github.run_id }}-${{ github.run_attempt }}"
|
||||
echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV"
|
||||
echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV"
|
||||
mkdir -p "$BASE/logs"
|
||||
echo "ASSETS_TEST_BASE_DIR=$BASE"
|
||||
|
||||
- name: Set DB URL for SQLite
|
||||
id: setdb
|
||||
shell: bash
|
||||
run: |
|
||||
if [ "${{ matrix.sqlite_mode }}" = "memory" ]; then
|
||||
echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///:memory:" >> "$GITHUB_ENV"
|
||||
else
|
||||
DBFILE="$RUNNER_TEMP/assets-tests.sqlite"
|
||||
mkdir -p "$(dirname "$DBFILE")"
|
||||
echo "ASSETS_TEST_DB_URL=sqlite+aiosqlite:///$DBFILE" >> "$GITHUB_ENV"
|
||||
fi
|
||||
|
||||
- name: Run tests
|
||||
run: python -m pytest tests-assets
|
||||
|
||||
- name: Show ComfyUI logs
|
||||
if: always()
|
||||
shell: bash
|
||||
run: |
|
||||
echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ===="
|
||||
echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ===="
|
||||
ls -la "$ASSETS_TEST_LOGS" || true
|
||||
for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do
|
||||
if [ -f "$f" ]; then
|
||||
echo "----- BEGIN $f -----"
|
||||
sed -n '1,400p' "$f"
|
||||
echo "----- END $f -----"
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Upload ComfyUI logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: asset-logs-sqlite-${{ matrix.sqlite_mode }}-py${{ matrix.python }}
|
||||
path: ${{ env.ASSETS_TEST_LOGS }}/*.log
|
||||
if-no-files-found: warn
|
||||
|
||||
postgres:
|
||||
name: PostgreSQL ${{ matrix.pgsql }} • Python ${{ matrix.python }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 40
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python: ['3.9', '3.12']
|
||||
pgsql: ['14', '16']
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:${{ matrix.pgsql }}
|
||||
env:
|
||||
POSTGRES_DB: assets
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd "pg_isready -U postgres -d assets"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 12
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install -U pip wheel
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install pytest pytest-aiohttp pytest-asyncio
|
||||
pip install greenlet psycopg
|
||||
|
||||
- name: Set deterministic test base dir
|
||||
id: basedir
|
||||
shell: bash
|
||||
run: |
|
||||
BASE="$RUNNER_TEMP/comfyui-assets-tests-${{ matrix.python }}-${{ matrix.sqlite_mode }}-${{ github.run_id }}-${{ github.run_attempt }}"
|
||||
echo "ASSETS_TEST_BASE_DIR=$BASE" >> "$GITHUB_ENV"
|
||||
echo "ASSETS_TEST_LOGS=$BASE/logs" >> "$GITHUB_ENV"
|
||||
mkdir -p "$BASE/logs"
|
||||
echo "ASSETS_TEST_BASE_DIR=$BASE"
|
||||
|
||||
- name: Set DB URL for PostgreSQL
|
||||
shell: bash
|
||||
run: |
|
||||
echo "ASSETS_TEST_DB_URL=postgresql+psycopg://postgres:postgres@localhost:5432/assets" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Run tests
|
||||
run: python -m pytest tests-assets
|
||||
|
||||
- name: Show ComfyUI logs
|
||||
if: always()
|
||||
shell: bash
|
||||
run: |
|
||||
echo "==== ASSETS_TEST_BASE_DIR: $ASSETS_TEST_BASE_DIR ===="
|
||||
echo "==== ASSETS_TEST_LOGS: $ASSETS_TEST_LOGS ===="
|
||||
ls -la "$ASSETS_TEST_LOGS" || true
|
||||
for f in "$ASSETS_TEST_LOGS"/stdout.log "$ASSETS_TEST_LOGS"/stderr.log; do
|
||||
if [ -f "$f" ]; then
|
||||
echo "----- BEGIN $f -----"
|
||||
sed -n '1,400p' "$f"
|
||||
echo "----- END $f -----"
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Upload ComfyUI logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: asset-logs-pgsql-${{ matrix.pgsql }}-py${{ matrix.python }}
|
||||
path: ${{ env.ASSETS_TEST_LOGS }}/*.log
|
||||
if-no-files-found: warn
|
||||
187
alembic_db/versions/0001_assets.py
Normal file
187
alembic_db/versions/0001_assets.py
Normal file
@ -0,0 +1,187 @@
|
||||
"""initial assets schema
|
||||
|
||||
Revision ID: 0001_assets
|
||||
Revises:
|
||||
Create Date: 2025-08-20 00:00:00
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision = "0001_assets"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ASSETS: content identity
|
||||
op.create_table(
|
||||
"assets",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("hash", sa.String(length=256), nullable=True),
|
||||
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column("mime_type", sa.String(length=255), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
if op.get_bind().dialect.name == "postgresql":
|
||||
op.create_index(
|
||||
"uq_assets_hash_not_null",
|
||||
"assets",
|
||||
["hash"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("hash IS NOT NULL"),
|
||||
)
|
||||
else:
|
||||
op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
|
||||
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
|
||||
|
||||
# ASSETS_INFO: user-visible references
|
||||
op.create_table(
|
||||
"assets_info",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
|
||||
sa.Column("name", sa.String(length=512), nullable=False),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("user_metadata", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
)
|
||||
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
|
||||
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
|
||||
op.create_index("ix_assets_info_name", "assets_info", ["name"])
|
||||
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
|
||||
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
|
||||
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
|
||||
|
||||
# TAGS: normalized tag vocabulary
|
||||
op.create_table(
|
||||
"tags",
|
||||
sa.Column("name", sa.String(length=512), primary_key=True),
|
||||
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
|
||||
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
|
||||
)
|
||||
op.create_index("ix_tags_tag_type", "tags", ["tag_type"])
|
||||
|
||||
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
|
||||
op.create_table(
|
||||
"asset_info_tags",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
|
||||
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
|
||||
)
|
||||
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
|
||||
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
|
||||
|
||||
# ASSET_CACHE_STATE: N:1 local cache rows per Asset
|
||||
op.create_table(
|
||||
"asset_cache_state",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
|
||||
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
|
||||
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
|
||||
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
|
||||
|
||||
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
|
||||
op.create_table(
|
||||
"asset_info_meta",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("key", sa.String(length=256), nullable=False),
|
||||
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("val_str", sa.String(length=2048), nullable=True),
|
||||
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
|
||||
sa.Column("val_bool", sa.Boolean(), nullable=True),
|
||||
sa.Column("val_json", sa.JSON().with_variant(postgresql.JSONB(), 'postgresql'), nullable=True),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
|
||||
)
|
||||
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
|
||||
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
|
||||
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
|
||||
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
|
||||
|
||||
# Tags vocabulary
|
||||
tags_table = sa.table(
|
||||
"tags",
|
||||
sa.column("name", sa.String(length=512)),
|
||||
sa.column("tag_type", sa.String()),
|
||||
)
|
||||
op.bulk_insert(
|
||||
tags_table,
|
||||
[
|
||||
{"name": "models", "tag_type": "system"},
|
||||
{"name": "input", "tag_type": "system"},
|
||||
{"name": "output", "tag_type": "system"},
|
||||
|
||||
{"name": "configs", "tag_type": "system"},
|
||||
{"name": "checkpoints", "tag_type": "system"},
|
||||
{"name": "loras", "tag_type": "system"},
|
||||
{"name": "vae", "tag_type": "system"},
|
||||
{"name": "text_encoders", "tag_type": "system"},
|
||||
{"name": "diffusion_models", "tag_type": "system"},
|
||||
{"name": "clip_vision", "tag_type": "system"},
|
||||
{"name": "style_models", "tag_type": "system"},
|
||||
{"name": "embeddings", "tag_type": "system"},
|
||||
{"name": "diffusers", "tag_type": "system"},
|
||||
{"name": "vae_approx", "tag_type": "system"},
|
||||
{"name": "controlnet", "tag_type": "system"},
|
||||
{"name": "gligen", "tag_type": "system"},
|
||||
{"name": "upscale_models", "tag_type": "system"},
|
||||
{"name": "hypernetworks", "tag_type": "system"},
|
||||
{"name": "photomaker", "tag_type": "system"},
|
||||
{"name": "classifiers", "tag_type": "system"},
|
||||
|
||||
{"name": "encoder", "tag_type": "system"},
|
||||
{"name": "decoder", "tag_type": "system"},
|
||||
|
||||
{"name": "missing", "tag_type": "system"},
|
||||
{"name": "rescan", "tag_type": "system"},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
|
||||
op.drop_table("asset_info_meta")
|
||||
|
||||
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
|
||||
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_table("asset_cache_state")
|
||||
|
||||
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
|
||||
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
|
||||
op.drop_table("asset_info_tags")
|
||||
|
||||
op.drop_index("ix_tags_tag_type", table_name="tags")
|
||||
op.drop_table("tags")
|
||||
|
||||
op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
|
||||
op.drop_table("assets_info")
|
||||
|
||||
if op.get_bind().dialect.name == "postgresql":
|
||||
op.drop_index("uq_assets_hash_not_null", table_name="assets")
|
||||
else:
|
||||
op.drop_index("uq_assets_hash", table_name="assets")
|
||||
op.drop_index("ix_assets_mime_type", table_name="assets")
|
||||
op.drop_table("assets")
|
||||
@ -1,40 +0,0 @@
|
||||
"""init
|
||||
|
||||
Revision ID: e9c714da8d57
|
||||
Revises:
|
||||
Create Date: 2025-05-30 20:14:33.772039
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'e9c714da8d57'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.create_table('model',
|
||||
sa.Column('type', sa.Text(), nullable=False),
|
||||
sa.Column('path', sa.Text(), nullable=False),
|
||||
sa.Column('file_name', sa.Text(), nullable=True),
|
||||
sa.Column('file_size', sa.Integer(), nullable=True),
|
||||
sa.Column('hash', sa.Text(), nullable=True),
|
||||
sa.Column('hash_algorithm', sa.Text(), nullable=True),
|
||||
sa.Column('source_url', sa.Text(), nullable=True),
|
||||
sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.PrimaryKeyConstraint('type', 'path')
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('model')
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,5 @@
|
||||
from .api.assets_routes import register_assets_system
|
||||
from .assets_scanner import sync_seed_assets
|
||||
from .database.db import init_db_engine
|
||||
|
||||
__all__ = ["init_db_engine", "sync_seed_assets", "register_assets_system"]
|
||||
225
app/_assets_helpers.py
Normal file
225
app/_assets_helpers.py
Normal file
@ -0,0 +1,225 @@
|
||||
import contextlib
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Sequence
|
||||
|
||||
import folder_paths
|
||||
|
||||
from .api import schemas_in
|
||||
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
|
||||
|
||||
We trust `folder_paths.folder_names_and_paths` and include a category if
|
||||
*any* of its base paths lies under the Comfy `models_dir`.
|
||||
"""
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
models_root = os.path.abspath(folder_paths.models_dir)
|
||||
for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
|
||||
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
- 'output' if the file resides under `folder_paths.get_output_directory()`
|
||||
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
|
||||
|
||||
Returns:
|
||||
(root_category, relative_path_inside_that_root)
|
||||
For 'models', the relative path is prefixed with the category name:
|
||||
e.g. ('models', 'vae/test/sub/ae.safetensors')
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
fp_abs = os.path.abspath(file_path)
|
||||
|
||||
def _is_within(child: str, parent: str) -> bool:
|
||||
try:
|
||||
return os.path.commonpath([child, parent]) == parent
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _rel(child: str, parent: str) -> str:
|
||||
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
|
||||
|
||||
# 1) input
|
||||
input_base = os.path.abspath(folder_paths.get_input_directory())
|
||||
if _is_within(fp_abs, input_base):
|
||||
return "input", _rel(fp_abs, input_base)
|
||||
|
||||
# 2) output
|
||||
output_base = os.path.abspath(folder_paths.get_output_directory())
|
||||
if _is_within(fp_abs, output_base):
|
||||
return "output", _rel(fp_abs, output_base)
|
||||
|
||||
# 3) models (check deepest matching base to avoid ambiguity)
|
||||
best: Optional[tuple[int, str, str]] = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
if not _is_within(fp_abs, base_abs):
|
||||
continue
|
||||
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
|
||||
if best is None or cand[0] > best[0]:
|
||||
best = cand
|
||||
|
||||
if best is not None:
|
||||
_, bucket, rel_inside = best
|
||||
combined = os.path.join(bucket, rel_inside)
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
|
||||
|
||||
|
||||
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return a tuple (name, tags) derived from a filesystem path.
|
||||
|
||||
Semantics:
|
||||
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
|
||||
- The returned `name` is the base filename with extension from the relative path.
|
||||
- The returned `tags` are:
|
||||
[root_category] + parent folders of the relative path (in order)
|
||||
For 'models', this means:
|
||||
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
|
||||
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
|
||||
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
p = Path(some_path)
|
||||
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
|
||||
|
||||
def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
|
||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
root = tags[0]
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
else:
|
||||
base_dir = os.path.abspath(
|
||||
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
|
||||
)
|
||||
raw_subdirs = tags[1:]
|
||||
for i in raw_subdirs:
|
||||
if i in (".", ".."):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
|
||||
|
||||
|
||||
def ensure_within_base(candidate: str, base: str) -> None:
|
||||
cand_abs = os.path.abspath(candidate)
|
||||
base_abs = os.path.abspath(base)
|
||||
try:
|
||||
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
|
||||
raise ValueError("destination escapes base directory")
|
||||
except Exception:
|
||||
raise ValueError("invalid destination path")
|
||||
|
||||
|
||||
def compute_relative_filename(file_path: str) -> Optional[str]:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
using forward slashes, eg:
|
||||
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
|
||||
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
|
||||
|
||||
For non-model paths, returns None.
|
||||
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
|
||||
"""
|
||||
try:
|
||||
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
p = Path(rel_path)
|
||||
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
|
||||
if not parts:
|
||||
return None
|
||||
|
||||
if root_category == "models":
|
||||
# parts[0] is the category ("checkpoints", "vae", etc) – drop it
|
||||
inside = parts[1:] if len(parts) > 1 else [parts[0]]
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def list_tree(base_dir: str) -> list[str]:
|
||||
out: list[str] = []
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
if not os.path.isdir(base_abs):
|
||||
return out
|
||||
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
|
||||
for name in filenames:
|
||||
out.append(os.path.abspath(os.path.join(dirpath, name)))
|
||||
return out
|
||||
|
||||
|
||||
def prefixes_for_root(root: schemas_in.RootType) -> list[str]:
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
return [os.path.abspath(p) for p in bases]
|
||||
if root == "input":
|
||||
return [os.path.abspath(folder_paths.get_input_directory())]
|
||||
if root == "output":
|
||||
return [os.path.abspath(folder_paths.get_output_directory())]
|
||||
return []
|
||||
|
||||
|
||||
def ts_to_iso(ts: Optional[float]) -> Optional[str]:
|
||||
if ts is None:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def new_scan_id(root: schemas_in.RootType) -> str:
|
||||
return f"scan-{root}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
out: list[str] = []
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
for rel_path in rel_files:
|
||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||
if not abs_path:
|
||||
continue
|
||||
abs_path = os.path.abspath(abs_path)
|
||||
allowed = False
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
with contextlib.suppress(Exception):
|
||||
if os.path.commonpath([abs_path, base_abs]) == base_abs:
|
||||
allowed = True
|
||||
break
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
0
app/api/__init__.py
Normal file
0
app/api/__init__.py
Normal file
543
app/api/assets_routes.py
Normal file
543
app/api/assets_routes.py
Normal file
@ -0,0 +1,543 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from aiohttp import web
|
||||
from pydantic import ValidationError
|
||||
|
||||
import folder_paths
|
||||
|
||||
from .. import assets_manager, assets_scanner, user_manager
|
||||
from . import schemas_in, schemas_out
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: Optional[user_manager.UserManager] = None
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# UUID regex (canonical hyphenated form, case-insensitive)
|
||||
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
|
||||
|
||||
@ROUTES.head("/api/assets/hash/{hash}")
|
||||
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||
hash_str = request.match_info.get("hash", "").strip().lower()
|
||||
if not hash_str or ":" not in hash_str:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = hash_str.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
exists = await assets_manager.asset_exists(asset_hash=hash_str)
|
||||
return web.Response(status=200 if exists else 404)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
async def list_assets(request: web.Request) -> web.Response:
|
||||
qp = request.rel_url.query
|
||||
query_dict = {}
|
||||
if "include_tags" in qp:
|
||||
query_dict["include_tags"] = qp.getall("include_tags")
|
||||
if "exclude_tags" in qp:
|
||||
query_dict["exclude_tags"] = qp.getall("exclude_tags")
|
||||
for k in ("name_contains", "metadata_filter", "limit", "offset", "sort", "order"):
|
||||
v = qp.get(k)
|
||||
if v is not None:
|
||||
query_dict[k] = v
|
||||
|
||||
try:
|
||||
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_QUERY", ve)
|
||||
|
||||
payload = await assets_manager.list_assets(
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=q.sort,
|
||||
order=q.order,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||
async def download_asset_content(request: web.Request) -> web.Response:
|
||||
disposition = request.query.get("disposition", "attachment").lower().strip()
|
||||
if disposition not in {"inline", "attachment"}:
|
||||
disposition = "attachment"
|
||||
|
||||
try:
|
||||
abs_path, content_type, filename = await assets_manager.resolve_asset_content_for_download(
|
||||
asset_info_id=str(uuid.UUID(request.match_info["id"])),
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
|
||||
except NotImplementedError as nie:
|
||||
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
|
||||
except FileNotFoundError:
|
||||
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
|
||||
|
||||
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
|
||||
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
|
||||
|
||||
resp = web.FileResponse(abs_path)
|
||||
resp.content_type = content_type
|
||||
resp.headers["Content-Disposition"] = cd
|
||||
return resp
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/from-hash")
|
||||
async def create_asset_from_hash(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
body = schemas_in.CreateFromHashBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
result = await assets_manager.create_asset_from_hash(
|
||||
hash_str=body.hash,
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
|
||||
return web.json_response(result.model_dump(mode="json"), status=201)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets")
|
||||
async def upload_asset(request: web.Request) -> web.Response:
|
||||
"""Multipart/form-data endpoint for Asset uploads."""
|
||||
|
||||
if not (request.content_type or "").lower().startswith("multipart/"):
|
||||
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
|
||||
|
||||
reader = await request.multipart()
|
||||
|
||||
file_present = False
|
||||
file_client_name: Optional[str] = None
|
||||
tags_raw: list[str] = []
|
||||
provided_name: Optional[str] = None
|
||||
user_metadata_raw: Optional[str] = None
|
||||
provided_hash: Optional[str] = None
|
||||
provided_hash_exists: Optional[bool] = None
|
||||
|
||||
file_written = 0
|
||||
tmp_path: Optional[str] = None
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
|
||||
fname = getattr(field, "name", "") or ""
|
||||
|
||||
if fname == "hash":
|
||||
try:
|
||||
s = ((await field.text()) or "").strip().lower()
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
|
||||
if s:
|
||||
if ":" not in s:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
provided_hash = f"{algo}:{digest}"
|
||||
try:
|
||||
provided_hash_exists = await assets_manager.asset_exists(asset_hash=provided_hash)
|
||||
except Exception:
|
||||
provided_hash_exists = None # do not fail the whole request here
|
||||
|
||||
elif fname == "file":
|
||||
file_present = True
|
||||
file_client_name = (field.filename or "").strip()
|
||||
|
||||
if provided_hash and provided_hash_exists is True:
|
||||
# If client supplied a hash that we know exists, drain but do not write to disk
|
||||
try:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
|
||||
continue # Do not create temp file; we will create AssetInfo from the existing content
|
||||
|
||||
# Otherwise, store to temp for hashing/ingest
|
||||
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
|
||||
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
|
||||
os.makedirs(unique_dir, exist_ok=True)
|
||||
tmp_path = os.path.join(unique_dir, ".upload.part")
|
||||
|
||||
try:
|
||||
with open(tmp_path, "wb") as f:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
try:
|
||||
if os.path.exists(tmp_path or ""):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
|
||||
elif fname == "tags":
|
||||
tags_raw.append((await field.text()) or "")
|
||||
elif fname == "name":
|
||||
provided_name = (await field.text()) or None
|
||||
elif fname == "user_metadata":
|
||||
user_metadata_raw = (await field.text()) or None
|
||||
|
||||
# If client did not send file, and we are not doing a from-hash fast path -> error
|
||||
if not file_present and not (provided_hash and provided_hash_exists):
|
||||
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
|
||||
|
||||
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
|
||||
# Empty upload is only acceptable if we are fast-pathing from existing hash
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
|
||||
|
||||
try:
|
||||
spec = schemas_in.UploadAssetSpec.model_validate({
|
||||
"tags": tags_raw,
|
||||
"name": provided_name,
|
||||
"user_metadata": user_metadata_raw,
|
||||
"hash": provided_hash,
|
||||
})
|
||||
except ValidationError as ve:
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
# Validate models category against configured folders (consistent with previous behavior)
|
||||
if spec.tags and spec.tags[0] == "models":
|
||||
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
return _error_response(
|
||||
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
|
||||
)
|
||||
|
||||
owner_id = USER_MANAGER.get_request_user_id(request)
|
||||
|
||||
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
|
||||
if spec.hash and provided_hash_exists is True:
|
||||
try:
|
||||
result = await assets_manager.create_asset_from_hash(
|
||||
hash_str=spec.hash,
|
||||
name=spec.name or (spec.hash.split(":", 1)[1]),
|
||||
tags=spec.tags,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
owner_id=owner_id,
|
||||
)
|
||||
except Exception:
|
||||
LOGGER.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
|
||||
|
||||
# Drain temp if we accidentally saved (e.g., hash field came after file)
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
with contextlib.suppress(Exception):
|
||||
os.remove(tmp_path)
|
||||
|
||||
status = 200 if (not result.created_new) else 201
|
||||
return web.json_response(result.model_dump(mode="json"), status=status)
|
||||
|
||||
# Otherwise, we must have a temp file path to ingest
|
||||
if not tmp_path or not os.path.exists(tmp_path):
|
||||
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
|
||||
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
|
||||
|
||||
try:
|
||||
created = await assets_manager.upload_asset_from_temp_path(
|
||||
spec,
|
||||
temp_path=tmp_path,
|
||||
client_filename=file_client_name,
|
||||
owner_id=owner_id,
|
||||
expected_asset_hash=spec.hash,
|
||||
)
|
||||
status = 201 if created.created_new else 200
|
||||
return web.json_response(created.model_dump(mode="json"), status=status)
|
||||
except ValueError as e:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
msg = str(e)
|
||||
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
|
||||
return _error_response(
|
||||
400,
|
||||
"HASH_MISMATCH",
|
||||
"Uploaded file hash does not match provided hash.",
|
||||
)
|
||||
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
|
||||
except Exception:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
LOGGER.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def get_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
result = await assets_manager.get_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"get_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def update_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = await assets_manager.update_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"update_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview")
|
||||
async def set_asset_preview(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
body = schemas_in.SetPreviewBody.model_validate(await request.json())
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = await assets_manager.set_asset_preview(
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=body.preview_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (PermissionError, ValueError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"set_asset_preview failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def delete_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
delete_content = request.query.get("delete_content")
|
||||
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
|
||||
|
||||
try:
|
||||
deleted = await assets_manager.delete_asset_reference(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
delete_content_if_orphan=delete_content,
|
||||
)
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if not deleted:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
|
||||
return web.Response(status=204)
|
||||
|
||||
|
||||
@ROUTES.get("/api/tags")
|
||||
async def get_tags(request: web.Request) -> web.Response:
|
||||
query_map = dict(request.rel_url.query)
|
||||
|
||||
try:
|
||||
query = schemas_in.TagsListQuery.model_validate(query_map)
|
||||
except ValidationError as ve:
|
||||
return web.json_response(
|
||||
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": ve.errors()}},
|
||||
status=400,
|
||||
)
|
||||
|
||||
result = await assets_manager.list_tags(
|
||||
prefix=query.prefix,
|
||||
limit=query.limit,
|
||||
offset=query.offset,
|
||||
order=query.order,
|
||||
include_zero=query.include_zero,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def add_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsAdd.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = await assets_manager.add_tags_to_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
origin="manual",
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsRemove.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = await assets_manager.remove_tags_from_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
LOGGER.exception(
|
||||
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/scan/seed")
|
||||
async def seed_assets(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
payload = {}
|
||||
|
||||
try:
|
||||
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
try:
|
||||
await assets_scanner.sync_seed_assets(body.roots)
|
||||
except Exception:
|
||||
LOGGER.exception("sync_seed_assets failed for roots=%s", body.roots)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response({"synced": True, "roots": body.roots}, status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/scan/schedule")
|
||||
async def schedule_asset_scan(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
payload = {}
|
||||
|
||||
try:
|
||||
body = schemas_in.ScheduleAssetScanBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
states = await assets_scanner.schedule_scans(body.roots)
|
||||
return web.json_response(states.model_dump(mode="json"), status=202)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets/scan")
|
||||
async def get_asset_scan_status(request: web.Request) -> web.Response:
|
||||
root = request.query.get("root", "").strip().lower()
|
||||
states = assets_scanner.current_statuses()
|
||||
if root in {"models", "input", "output"}:
|
||||
states = [s for s in states.scans if s.root == root] # type: ignore
|
||||
states = schemas_out.AssetScanStatusResponse(scans=states)
|
||||
return web.json_response(states.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
||||
global USER_MANAGER
|
||||
USER_MANAGER = user_manager_instance
|
||||
app.add_routes(ROUTES)
|
||||
|
||||
|
||||
def _error_response(status: int, code: str, message: str, details: Optional[dict] = None) -> web.Response:
|
||||
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
|
||||
|
||||
|
||||
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
||||
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
||||
297
app/api/schemas_in.py
Normal file
297
app/api/schemas_in.py
Normal file
@ -0,0 +1,297 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
conint,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
name_contains: Optional[str] = None
|
||||
|
||||
# Accept either a JSON string (query param) or a dict
|
||||
metadata_filter: Optional[dict[str, Any]] = None
|
||||
|
||||
limit: conint(ge=1, le=500) = 20
|
||||
offset: conint(ge=0) = 0
|
||||
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
|
||||
order: Literal["asc", "desc"] = "desc"
|
||||
|
||||
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||
@classmethod
|
||||
def _split_csv_tags(cls, v):
|
||||
# Accept "a,b,c" or ["a","b"] (we are liberal in what we accept)
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
return [t.strip() for t in v.split(",") if t.strip()]
|
||||
if isinstance(v, list):
|
||||
out: list[str] = []
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
out.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||
return out
|
||||
return v
|
||||
|
||||
@field_validator("metadata_filter", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v
|
||||
if isinstance(v, str) and v.strip():
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
except Exception as e:
|
||||
raise ValueError(f"metadata_filter must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("metadata_filter must be a JSON object")
|
||||
return parsed
|
||||
return None
|
||||
|
||||
|
||||
class UpdateAssetBody(BaseModel):
|
||||
name: Optional[str] = None
|
||||
tags: Optional[list[str]] = None
|
||||
user_metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _at_least_one(self):
|
||||
if self.name is None and self.tags is None and self.user_metadata is None:
|
||||
raise ValueError("Provide at least one of: name, tags, user_metadata.")
|
||||
if self.tags is not None:
|
||||
if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags):
|
||||
raise ValueError("Field 'tags' must be an array of strings.")
|
||||
return self
|
||||
|
||||
|
||||
class CreateFromHashBody(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
hash: str
|
||||
name: str
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("hash")
|
||||
@classmethod
|
||||
def _require_blake3(cls, v):
|
||||
s = (v or "").strip().lower()
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return s
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _tags_norm(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, list):
|
||||
out = [str(t).strip().lower() for t in v if str(t).strip()]
|
||||
seen = set()
|
||||
dedup = []
|
||||
for t in out:
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
dedup.append(t)
|
||||
return dedup
|
||||
if isinstance(v, str):
|
||||
return [t.strip().lower() for t in v.split(",") if t.strip()]
|
||||
return []
|
||||
|
||||
|
||||
class TagsListQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
prefix: Optional[str] = Field(None, min_length=1, max_length=256)
|
||||
limit: int = Field(100, ge=1, le=1000)
|
||||
offset: int = Field(0, ge=0, le=10_000_000)
|
||||
order: Literal["count_desc", "name_asc"] = "count_desc"
|
||||
include_zero: bool = True
|
||||
|
||||
@field_validator("prefix")
|
||||
@classmethod
|
||||
def normalize_prefix(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
return v
|
||||
v = v.strip()
|
||||
return v.lower() or None
|
||||
|
||||
|
||||
class TagsAdd(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
|
||||
@field_validator("tags")
|
||||
@classmethod
|
||||
def normalize_tags(cls, v: list[str]) -> list[str]:
|
||||
out = []
|
||||
for t in v:
|
||||
if not isinstance(t, str):
|
||||
raise TypeError("tags must be strings")
|
||||
tnorm = t.strip().lower()
|
||||
if tnorm:
|
||||
out.append(tnorm)
|
||||
seen = set()
|
||||
deduplicated = []
|
||||
for x in out:
|
||||
if x not in seen:
|
||||
seen.add(x)
|
||||
deduplicated.append(x)
|
||||
return deduplicated
|
||||
|
||||
|
||||
class TagsRemove(TagsAdd):
|
||||
pass
|
||||
|
||||
|
||||
RootType = Literal["models", "input", "output"]
|
||||
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
|
||||
|
||||
|
||||
class ScheduleAssetScanBody(BaseModel):
|
||||
roots: list[RootType] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class UploadAssetSpec(BaseModel):
|
||||
"""Upload Asset operation.
|
||||
- tags: ordered; first is root ('models'|'input'|'output');
|
||||
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
|
||||
- name: display name
|
||||
- user_metadata: arbitrary JSON object (optional)
|
||||
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
|
||||
|
||||
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
|
||||
and the original extension is preserved when available.
|
||||
"""
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
name: Optional[str] = Field(default=None, max_length=512, description="Display Name")
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
hash: Optional[str] = Field(default=None)
|
||||
|
||||
@field_validator("hash", mode="before")
|
||||
@classmethod
|
||||
def _parse_hash(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip().lower()
|
||||
if not s:
|
||||
return None
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return f"{algo}:{digest}"
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _parse_tags(cls, v):
|
||||
"""
|
||||
Accepts a list of strings (possibly multiple form fields),
|
||||
where each string can be:
|
||||
- JSON array (e.g., '["models","loras","foo"]')
|
||||
- comma-separated ('models, loras, foo')
|
||||
- single token ('models')
|
||||
Returns a normalized, deduplicated, ordered list.
|
||||
"""
|
||||
items: list[str] = []
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
v = [v]
|
||||
|
||||
if isinstance(v, list):
|
||||
for item in v:
|
||||
if item is None:
|
||||
continue
|
||||
s = str(item).strip()
|
||||
if not s:
|
||||
continue
|
||||
if s.startswith("["):
|
||||
try:
|
||||
arr = json.loads(s)
|
||||
if isinstance(arr, list):
|
||||
items.extend(str(x) for x in arr)
|
||||
continue
|
||||
except Exception:
|
||||
pass # fallback to CSV parse below
|
||||
items.extend([p for p in s.split(",") if p.strip()])
|
||||
else:
|
||||
return []
|
||||
|
||||
# normalize + dedupe
|
||||
norm = []
|
||||
seen = set()
|
||||
for t in items:
|
||||
tnorm = str(t).strip().lower()
|
||||
if tnorm and tnorm not in seen:
|
||||
seen.add(tnorm)
|
||||
norm.append(tnorm)
|
||||
return norm
|
||||
|
||||
@field_validator("user_metadata", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v or {}
|
||||
if isinstance(v, str):
|
||||
s = v.strip()
|
||||
if not s:
|
||||
return {}
|
||||
try:
|
||||
parsed = json.loads(s)
|
||||
except Exception as e:
|
||||
raise ValueError(f"user_metadata must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("user_metadata must be a JSON object")
|
||||
return parsed
|
||||
return {}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_order(self):
|
||||
if not self.tags:
|
||||
raise ValueError("tags must be provided and non-empty")
|
||||
root = self.tags[0]
|
||||
if root not in {"models", "input", "output"}:
|
||||
raise ValueError("first tag must be one of: models, input, output")
|
||||
if root == "models":
|
||||
if len(self.tags) < 2:
|
||||
raise ValueError("models uploads require a category tag as the second tag")
|
||||
return self
|
||||
|
||||
|
||||
class SetPreviewBody(BaseModel):
|
||||
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
|
||||
preview_id: Optional[str] = None
|
||||
|
||||
@field_validator("preview_id", mode="before")
|
||||
@classmethod
|
||||
def _norm_uuid(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip()
|
||||
if not s:
|
||||
return None
|
||||
try:
|
||||
uuid.UUID(s)
|
||||
except Exception:
|
||||
raise ValueError("preview_id must be a UUID")
|
||||
return s
|
||||
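For illustration, a minimal sketch of how the UploadAssetSpec validators above behave (assuming the model is importable from app/api/schemas_in.py; the tag values and digest below are made up): tags may arrive as JSON-array or comma-separated form fields, user_metadata may be a JSON string, and hash must be the canonical lowercase 'blake3:<hex>' form.

from app.api.schemas_in import UploadAssetSpec

spec = UploadAssetSpec(
    tags=['["models", "loras"]', "My-Tag, my-tag"],   # JSON array and CSV fields are both accepted
    name="example.safetensors",
    user_metadata='{"source": "local"}',               # JSON string is parsed into a dict
    hash="blake3:" + "ab" * 32,                        # 64-char lowercase hex digest (placeholder)
)
assert spec.tags == ["models", "loras", "my-tag"]      # normalized, de-duplicated, order preserved
assert spec.user_metadata == {"source": "local"}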
115
app/api/schemas_out.py
Normal file
@ -0,0 +1,115 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
|
||||
|
||||
class AssetSummary(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: Optional[str]
|
||||
size: Optional[int] = None
|
||||
mime_type: Optional[str] = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
preview_url: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
last_access_time: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "updated_at", "last_access_time")
|
||||
def _ser_dt(self, v: Optional[datetime], _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetsList(BaseModel):
|
||||
assets: list[AssetSummary]
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AssetUpdated(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: Optional[str]
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("updated_at")
|
||||
def _ser_updated(self, v: Optional[datetime], _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: Optional[str]
|
||||
size: Optional[int] = None
|
||||
mime_type: Optional[str] = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
preview_id: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
last_access_time: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "last_access_time")
|
||||
def _ser_dt(self, v: Optional[datetime], _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetCreated(AssetDetail):
|
||||
created_new: bool
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
type: str
|
||||
|
||||
|
||||
class TagsList(BaseModel):
|
||||
tags: list[TagUsage] = Field(default_factory=list)
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class TagsAdd(BaseModel):
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
added: list[str] = Field(default_factory=list)
|
||||
already_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TagsRemove(BaseModel):
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
removed: list[str] = Field(default_factory=list)
|
||||
not_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AssetScanError(BaseModel):
|
||||
path: str
|
||||
message: str
|
||||
at: Optional[str] = Field(None, description="ISO timestamp")
|
||||
|
||||
|
||||
class AssetScanStatus(BaseModel):
|
||||
scan_id: str
|
||||
root: Literal["models", "input", "output"]
|
||||
status: Literal["scheduled", "running", "completed", "failed", "cancelled"]
|
||||
scheduled_at: Optional[str] = None
|
||||
started_at: Optional[str] = None
|
||||
finished_at: Optional[str] = None
|
||||
discovered: int = 0
|
||||
processed: int = 0
|
||||
file_errors: list[AssetScanError] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AssetScanStatusResponse(BaseModel):
|
||||
scans: list[AssetScanStatus] = Field(default_factory=list)
|
||||
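A small sketch of how these response models serialize (assuming pydantic v2 and the module path app/api/schemas_out.py; the id and hash values are placeholders): datetimes are emitted as ISO-8601 strings via the field_serializer hooks.

from datetime import datetime, timezone
from app.api.schemas_out import AssetSummary

summary = AssetSummary(
    id="00000000-0000-0000-0000-000000000000",        # placeholder UUID
    name="example.safetensors",
    asset_hash="blake3:" + "ab" * 32,                  # placeholder digest
    size=1024,
    tags=["models", "loras"],
    preview_url="/api/assets/00000000-0000-0000-0000-000000000000/content",
    created_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
)
print(summary.model_dump()["created_at"])              # "2025-01-01T00:00:00+00:00"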
556
app/assets_manager.py
Normal file
@ -0,0 +1,556 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from comfy_api.internal import async_to_sync
|
||||
|
||||
from ._assets_helpers import (
|
||||
ensure_within_base,
|
||||
get_name_and_tags_from_asset_path,
|
||||
resolve_destination_from_tags,
|
||||
)
|
||||
from .api import schemas_in, schemas_out
|
||||
from .database.db import create_session
|
||||
from .database.models import Asset
|
||||
from .database.services import (
|
||||
add_tags_to_asset_info,
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
check_fs_asset_exists_quick,
|
||||
create_asset_info_for_existing_asset,
|
||||
delete_asset_info_by_id,
|
||||
fetch_asset_info_and_asset,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
get_asset_tags,
|
||||
ingest_fs_asset,
|
||||
list_asset_infos_page,
|
||||
list_cache_states_by_asset_id,
|
||||
list_tags_with_usage,
|
||||
pick_best_live_path,
|
||||
remove_tags_from_asset_info,
|
||||
set_asset_info_preview,
|
||||
touch_asset_info_by_id,
|
||||
touch_asset_infos_by_fs_path,
|
||||
update_asset_info_full,
|
||||
)
|
||||
from .storage import hashing
|
||||
|
||||
|
||||
async def asset_exists(*, asset_hash: str) -> bool:
|
||||
async with await create_session() as session:
|
||||
return await asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
|
||||
def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None:
|
||||
if tags is None:
|
||||
tags = []
|
||||
try:
|
||||
asset_name, path_tags = get_name_and_tags_from_asset_path(file_path)
|
||||
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
|
||||
add_local_asset,
|
||||
tags=list(dict.fromkeys([*path_tags, *tags])),
|
||||
file_name=asset_name,
|
||||
file_path=file_path,
|
||||
)
|
||||
except ValueError as e:
|
||||
logging.warning("Skipping non-asset path %s: %s", file_path, e)
|
||||
|
||||
|
||||
async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
|
||||
abs_path = os.path.abspath(file_path)
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(abs_path)
|
||||
if not size_bytes:
|
||||
return
|
||||
|
||||
async with await create_session() as session:
|
||||
if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns):
|
||||
await touch_asset_infos_by_fs_path(session, file_path=abs_path)
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
asset_hash = hashing.blake3_hash_sync(abs_path)
|
||||
|
||||
async with await create_session() as session:
|
||||
await ingest_fs_asset(
|
||||
session,
|
||||
asset_hash="blake3:" + asset_hash,
|
||||
abs_path=abs_path,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=None,
|
||||
info_name=file_name,
|
||||
tag_origin="automatic",
|
||||
tags=tags,
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def list_assets(
|
||||
*,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
name_contains: Optional[str] = None,
|
||||
metadata_filter: Optional[dict] = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetsList:
|
||||
sort = _safe_sort_field(sort)
|
||||
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
||||
|
||||
async with await create_session() as session:
|
||||
infos, tag_map, total = await list_asset_infos_page(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
summaries: list[schemas_out.AssetSummary] = []
|
||||
for info in infos:
|
||||
asset = info.asset
|
||||
tags = tag_map.get(info.id, [])
|
||||
summaries.append(
|
||||
schemas_out.AssetSummary(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
preview_url=f"/api/assets/{info.id}/content",
|
||||
created_at=info.created_at,
|
||||
updated_at=info.updated_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
)
|
||||
|
||||
return schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
total=total,
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
|
||||
async def get_asset(*, asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
|
||||
async with await create_session() as session:
|
||||
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
info, asset, tag_names = res
|
||||
preview_id = info.preview_id
|
||||
|
||||
return schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
async def resolve_asset_content_for_download(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[str, str, str]:
|
||||
async with await create_session() as session:
|
||||
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info, asset = pair
|
||||
states = await list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = pick_best_live_path(states)
|
||||
if not abs_path:
|
||||
raise FileNotFoundError
|
||||
|
||||
await touch_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
await session.commit()
|
||||
|
||||
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
|
||||
download_name = info.name or os.path.basename(abs_path)
|
||||
return abs_path, ctype, download_name
|
||||
|
||||
|
||||
async def upload_asset_from_temp_path(
|
||||
spec: schemas_in.UploadAssetSpec,
|
||||
*,
|
||||
temp_path: str,
|
||||
client_filename: Optional[str] = None,
|
||||
owner_id: str = "",
|
||||
expected_asset_hash: Optional[str] = None,
|
||||
) -> schemas_out.AssetCreated:
|
||||
try:
|
||||
digest = await hashing.blake3_hash(temp_path)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
|
||||
raise ValueError("HASH_MISMATCH")
|
||||
|
||||
async with await create_session() as session:
|
||||
existing = await get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if existing is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
|
||||
info = await create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
name=display_name,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = await get_asset_tags(session, asset_info_id=info.id)
|
||||
await session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=existing.hash,
|
||||
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
|
||||
mime_type=existing.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
|
||||
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
src_for_ext = (client_filename or spec.name or "").strip()
|
||||
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||
hashed_basename = f"{digest}{ext}"
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
ensure_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
try:
|
||||
os.replace(temp_path, dest_abs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||
|
||||
try:
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||
|
||||
async with await create_session() as session:
|
||||
result = await ingest_fs_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
abs_path=dest_abs,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=None,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags,
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
info_id = result["asset_info_id"]
|
||||
if not info_id:
|
||||
raise RuntimeError("failed to create asset metadata")
|
||||
|
||||
pair = await fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
info, asset = pair
|
||||
tag_names = await get_asset_tags(session, asset_info_id=info.id)
|
||||
await session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=result["asset_created"],
|
||||
)
|
||||
|
||||
|
||||
async def update_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetUpdated:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
info = await update_asset_info_full(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
tags=tags,
|
||||
user_metadata=user_metadata,
|
||||
tag_origin="manual",
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
|
||||
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
await session.commit()
|
||||
|
||||
return schemas_out.AssetUpdated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset.hash if info.asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
updated_at=info.updated_at,
|
||||
)
|
||||
|
||||
|
||||
async def set_asset_preview(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: Optional[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
await set_asset_info_preview(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
info, asset, tags = res
|
||||
await session.commit()
|
||||
|
||||
return schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
asset_id = info_row.asset_id if info_row else None
|
||||
deleted = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not deleted:
|
||||
await session.commit()
|
||||
return False
|
||||
|
||||
if not delete_content_if_orphan or not asset_id:
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
still_exists = await asset_info_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
states = await list_cache_states_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
|
||||
|
||||
asset_row = await session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
await session.delete(asset_row)
|
||||
|
||||
await session.commit()
|
||||
for p in file_paths:
|
||||
with contextlib.suppress(Exception):
|
||||
if p and os.path.isfile(p):
|
||||
os.remove(p)
|
||||
return True
|
||||
|
||||
|
||||
async def create_asset_from_hash(
|
||||
*,
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: Optional[list[str]] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
owner_id: str = "",
|
||||
) -> Optional[schemas_out.AssetCreated]:
|
||||
canonical = hash_str.strip().lower()
|
||||
async with await create_session() as session:
|
||||
asset = await get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
info = await create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=canonical,
|
||||
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = await get_asset_tags(session, asset_info_id=info.id)
|
||||
await session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
|
||||
async def list_tags(
|
||||
*,
|
||||
prefix: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
order: str = "count_desc",
|
||||
include_zero: bool = True,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsList:
|
||||
limit = max(1, min(1000, limit))
|
||||
offset = max(0, offset)
|
||||
|
||||
async with await create_session() as session:
|
||||
rows, total = await list_tags_with_usage(
|
||||
session,
|
||||
prefix=prefix,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
include_zero=include_zero,
|
||||
order=order,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
|
||||
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
|
||||
|
||||
|
||||
async def add_tags_to_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsAdd:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
data = await add_tags_to_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=origin,
|
||||
create_if_missing=True,
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
await session.commit()
|
||||
return schemas_out.TagsAdd(**data)
|
||||
|
||||
|
||||
async def remove_tags_from_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsRemove:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
data = await remove_tags_from_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
)
|
||||
await session.commit()
|
||||
return schemas_out.TagsRemove(**data)
|
||||
|
||||
|
||||
def _safe_sort_field(requested: Optional[str]) -> str:
|
||||
if not requested:
|
||||
return "created_at"
|
||||
v = requested.lower()
|
||||
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
|
||||
return v
|
||||
return "created_at"
|
||||
|
||||
|
||||
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
|
||||
def _safe_filename(name: Optional[str], fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
if n:
|
||||
return n
|
||||
return fallback
|
||||
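A usage sketch for the asset manager above (assuming the package is importable as app.assets_manager and app.database.db as laid out in this diff, and that args.database_url is configured; the tag filter is just an example):

import asyncio
from app import assets_manager
from app.database.db import init_db_engine

async def demo():
    await init_db_engine()                              # engine + migrations must be ready first
    page = await assets_manager.list_assets(include_tags=["models"], limit=10, sort="name", order="asc")
    for item in page.assets:
        print(item.name, item.tags, item.preview_url)
    if page.assets:
        detail = await assets_manager.get_asset(asset_info_id=page.assets[0].id)
        print(detail.user_metadata)

asyncio.run(demo())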
501
app/assets_scanner.py
Normal file
@ -0,0 +1,501 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
import folder_paths
|
||||
|
||||
from ._assets_helpers import (
|
||||
collect_models_files,
|
||||
compute_relative_filename,
|
||||
get_comfy_models_folders,
|
||||
get_name_and_tags_from_asset_path,
|
||||
list_tree,
|
||||
new_scan_id,
|
||||
prefixes_for_root,
|
||||
ts_to_iso,
|
||||
)
|
||||
from .api import schemas_in, schemas_out
|
||||
from .database.db import create_session
|
||||
from .database.helpers import (
|
||||
add_missing_tag_for_asset_id,
|
||||
ensure_tags_exist,
|
||||
escape_like_prefix,
|
||||
fast_asset_file_check,
|
||||
remove_missing_tag_for_asset_id,
|
||||
seed_from_paths_batch,
|
||||
)
|
||||
from .database.models import Asset, AssetCacheState, AssetInfo
|
||||
from .database.services import (
|
||||
compute_hash_and_dedup_for_cache_state,
|
||||
list_cache_states_by_asset_id,
|
||||
list_cache_states_with_asset_under_prefixes,
|
||||
list_unhashed_candidates_under_prefixes,
|
||||
list_verify_candidates_under_prefixes,
|
||||
)
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
SLOW_HASH_CONCURRENCY = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanProgress:
|
||||
scan_id: str
|
||||
root: schemas_in.RootType
|
||||
status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled"
|
||||
scheduled_at: float = field(default_factory=lambda: time.time())
|
||||
started_at: Optional[float] = None
|
||||
finished_at: Optional[float] = None
|
||||
discovered: int = 0
|
||||
processed: int = 0
|
||||
file_errors: list[dict] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlowQueueState:
|
||||
queue: asyncio.Queue
|
||||
workers: list[asyncio.Task] = field(default_factory=list)
|
||||
closed: bool = False
|
||||
|
||||
|
||||
RUNNING_TASKS: dict[schemas_in.RootType, asyncio.Task] = {}
|
||||
PROGRESS_BY_ROOT: dict[schemas_in.RootType, ScanProgress] = {}
|
||||
SLOW_STATE_BY_ROOT: dict[schemas_in.RootType, SlowQueueState] = {}
|
||||
|
||||
|
||||
def current_statuses() -> schemas_out.AssetScanStatusResponse:
|
||||
scans = []
|
||||
for root in schemas_in.ALLOWED_ROOTS:
|
||||
prog = PROGRESS_BY_ROOT.get(root)
|
||||
if not prog:
|
||||
continue
|
||||
scans.append(_scan_progress_to_scan_status_model(prog))
|
||||
return schemas_out.AssetScanStatusResponse(scans=scans)
|
||||
|
||||
|
||||
async def schedule_scans(roots: list[schemas_in.RootType]) -> schemas_out.AssetScanStatusResponse:
|
||||
results: list[ScanProgress] = []
|
||||
for root in roots:
|
||||
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
|
||||
results.append(PROGRESS_BY_ROOT[root])
|
||||
continue
|
||||
|
||||
prog = ScanProgress(scan_id=new_scan_id(root), root=root, status="scheduled")
|
||||
PROGRESS_BY_ROOT[root] = prog
|
||||
state = SlowQueueState(queue=asyncio.Queue())
|
||||
SLOW_STATE_BY_ROOT[root] = state
|
||||
RUNNING_TASKS[root] = asyncio.create_task(
|
||||
_run_hash_verify_pipeline(root, prog, state),
|
||||
name=f"asset-scan:{root}",
|
||||
)
|
||||
results.append(prog)
|
||||
return _status_response_for(results)
|
||||
|
||||
|
||||
async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None:
|
||||
t_total = time.perf_counter()
|
||||
created = 0
|
||||
skipped_existing = 0
|
||||
paths: list[str] = []
|
||||
try:
|
||||
existing_paths: set[str] = set()
|
||||
for r in roots:
|
||||
try:
|
||||
survivors = await _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
|
||||
if survivors:
|
||||
existing_paths.update(survivors)
|
||||
except Exception as ex:
|
||||
LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex)
|
||||
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_input_directory()))
|
||||
if "output" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_output_directory()))
|
||||
|
||||
specs: list[dict] = []
|
||||
tag_pool: set[str] = set()
|
||||
for p in paths:
|
||||
ap = os.path.abspath(p)
|
||||
if ap in existing_paths:
|
||||
skipped_existing += 1
|
||||
continue
|
||||
try:
|
||||
st = os.stat(ap, follow_symlinks=True)
|
||||
except OSError:
|
||||
continue
|
||||
if not st.st_size:
|
||||
continue
|
||||
name, tags = get_name_and_tags_from_asset_path(ap)
|
||||
specs.append(
|
||||
{
|
||||
"abs_path": ap,
|
||||
"size_bytes": st.st_size,
|
||||
"mtime_ns": getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)),
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": compute_relative_filename(ap),
|
||||
}
|
||||
)
|
||||
for t in tags:
|
||||
tag_pool.add(t)
|
||||
|
||||
if not specs:
|
||||
return
|
||||
async with await create_session() as sess:
|
||||
if tag_pool:
|
||||
await ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||
|
||||
result = await seed_from_paths_batch(sess, specs=specs, owner_id="")
|
||||
created += result["inserted_infos"]
|
||||
await sess.commit()
|
||||
finally:
|
||||
LOGGER.info(
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
|
||||
roots,
|
||||
time.perf_counter() - t_total,
|
||||
created,
|
||||
skipped_existing,
|
||||
len(paths),
|
||||
)
|
||||
|
||||
|
||||
def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse:
|
||||
return schemas_out.AssetScanStatusResponse(scans=[_scan_progress_to_scan_status_model(p) for p in progresses])
|
||||
|
||||
|
||||
def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.AssetScanStatus:
|
||||
return schemas_out.AssetScanStatus(
|
||||
scan_id=progress.scan_id,
|
||||
root=progress.root,
|
||||
status=progress.status,
|
||||
scheduled_at=ts_to_iso(progress.scheduled_at),
|
||||
started_at=ts_to_iso(progress.started_at),
|
||||
finished_at=ts_to_iso(progress.finished_at),
|
||||
discovered=progress.discovered,
|
||||
processed=progress.processed,
|
||||
file_errors=[
|
||||
schemas_out.AssetScanError(
|
||||
path=e.get("path", ""),
|
||||
message=e.get("message", ""),
|
||||
at=e.get("at"),
|
||||
)
|
||||
for e in (progress.file_errors or [])
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def _run_hash_verify_pipeline(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
|
||||
prog.status = "running"
|
||||
prog.started_at = time.time()
|
||||
try:
|
||||
prefixes = prefixes_for_root(root)
|
||||
|
||||
await _fast_db_consistency_pass(root)
|
||||
|
||||
# collect candidates from DB
|
||||
async with await create_session() as sess:
|
||||
verify_ids = await list_verify_candidates_under_prefixes(sess, prefixes=prefixes)
|
||||
unhashed_ids = await list_unhashed_candidates_under_prefixes(sess, prefixes=prefixes)
|
||||
# dedupe: prioritize verification first
|
||||
seen = set()
|
||||
ordered: list[int] = []
|
||||
for lst in (verify_ids, unhashed_ids):
|
||||
for sid in lst:
|
||||
if sid not in seen:
|
||||
seen.add(sid)
|
||||
ordered.append(sid)
|
||||
|
||||
prog.discovered = len(ordered)
|
||||
|
||||
# queue up work
|
||||
for sid in ordered:
|
||||
await state.queue.put(sid)
|
||||
state.closed = True
|
||||
_start_state_workers(root, prog, state)
|
||||
await _await_state_workers_then_finish(root, prog, state)
|
||||
except asyncio.CancelledError:
|
||||
prog.status = "cancelled"
|
||||
raise
|
||||
except Exception as exc:
|
||||
_append_error(prog, path="", message=str(exc))
|
||||
prog.status = "failed"
|
||||
prog.finished_at = time.time()
|
||||
LOGGER.exception("Asset scan failed for %s", root)
|
||||
finally:
|
||||
RUNNING_TASKS.pop(root, None)
|
||||
|
||||
|
||||
async def _reconcile_missing_tags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None:
|
||||
"""
|
||||
Detect missing files quickly and toggle the 'missing' tag per asset_id.
|
||||
|
||||
Rules:
|
||||
- Only hashed assets (assets.hash != NULL) participate in missing tagging.
|
||||
- We consider ALL cache states of the asset (across roots) before tagging.
|
||||
"""
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
elif root == "input":
|
||||
bases = [folder_paths.get_input_directory()]
|
||||
else:
|
||||
bases = [folder_paths.get_output_directory()]
|
||||
|
||||
try:
|
||||
async with await create_session() as sess:
|
||||
# state + hash + size for the current root
|
||||
rows = await list_cache_states_with_asset_under_prefixes(sess, prefixes=bases)
|
||||
|
||||
# Track fast_ok within the scanned root and whether the asset is hashed
|
||||
by_asset: dict[str, dict[str, bool]] = {}
|
||||
for state, a_hash, size_db in rows:
|
||||
aid = state.asset_id
|
||||
acc = by_asset.get(aid)
|
||||
if acc is None:
|
||||
acc = {"any_fast_ok_here": False, "hashed": (a_hash is not None), "size_db": int(size_db or 0)}
|
||||
by_asset[aid] = acc
|
||||
try:
|
||||
if acc["hashed"]:
|
||||
st = os.stat(state.file_path, follow_symlinks=True)
|
||||
if fast_asset_file_check(mtime_db=state.mtime_ns, size_db=acc["size_db"], stat_result=st):
|
||||
acc["any_fast_ok_here"] = True
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except OSError as e:
|
||||
_append_error(prog, path=state.file_path, message=str(e))
|
||||
|
||||
# Decide per asset, considering ALL its states (not just this root)
|
||||
for aid, acc in by_asset.items():
|
||||
try:
|
||||
if not acc["hashed"]:
|
||||
# Never tag seed assets as missing
|
||||
continue
|
||||
|
||||
any_fast_ok_global = acc["any_fast_ok_here"]
|
||||
if not any_fast_ok_global:
|
||||
# Check other states outside this root
|
||||
others = await list_cache_states_by_asset_id(sess, asset_id=aid)
|
||||
for st in others:
|
||||
try:
|
||||
any_fast_ok_global = fast_asset_file_check(
|
||||
mtime_db=st.mtime_ns,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(st.file_path, follow_symlinks=True),
|
||||
)
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
if any_fast_ok_global:
|
||||
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
||||
else:
|
||||
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
||||
except Exception as ex:
|
||||
_append_error(prog, path="", message=f"reconcile {aid[:8]}: {ex}")
|
||||
|
||||
await sess.commit()
|
||||
except Exception as e:
|
||||
_append_error(prog, path="", message=f"reconcile failed: {e}")
|
||||
|
||||
|
||||
def _start_state_workers(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
|
||||
if state.workers:
|
||||
return
|
||||
|
||||
async def _worker(_wid: int):
|
||||
while True:
|
||||
sid = await state.queue.get()
|
||||
try:
|
||||
if sid is None:
|
||||
return
|
||||
try:
|
||||
async with await create_session() as sess:
|
||||
# Optional: fetch path for better error messages
|
||||
st = await sess.get(AssetCacheState, sid)
|
||||
try:
|
||||
await compute_hash_and_dedup_for_cache_state(sess, state_id=sid)
|
||||
await sess.commit()
|
||||
except Exception as e:
|
||||
path = st.file_path if st else f"state:{sid}"
|
||||
_append_error(prog, path=path, message=str(e))
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
prog.processed += 1
|
||||
finally:
|
||||
state.queue.task_done()
|
||||
|
||||
state.workers = [
|
||||
asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}")
|
||||
for i in range(SLOW_HASH_CONCURRENCY)
|
||||
]
|
||||
|
||||
async def _close_when_ready():
|
||||
while not state.closed:
|
||||
await asyncio.sleep(0.05)
|
||||
for _ in range(SLOW_HASH_CONCURRENCY):
|
||||
await state.queue.put(None)
|
||||
|
||||
asyncio.create_task(_close_when_ready())
|
||||
|
||||
|
||||
async def _await_state_workers_then_finish(
|
||||
root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState
|
||||
) -> None:
|
||||
if state.workers:
|
||||
await asyncio.gather(*state.workers, return_exceptions=True)
|
||||
await _reconcile_missing_tags_for_root(root, prog)
|
||||
prog.finished_at = time.time()
|
||||
prog.status = "completed"
|
||||
|
||||
|
||||
def _append_error(prog: ScanProgress, *, path: str, message: str) -> None:
|
||||
prog.file_errors.append({
|
||||
"path": path,
|
||||
"message": message,
|
||||
"at": ts_to_iso(time.time()),
|
||||
})
|
||||
|
||||
|
||||
async def _fast_db_consistency_pass(
|
||||
root: schemas_in.RootType,
|
||||
*,
|
||||
collect_existing_paths: bool = False,
|
||||
update_missing_tags: bool = False,
|
||||
) -> Optional[set[str]]:
|
||||
"""Fast DB+FS pass for a root:
|
||||
- Toggle needs_verify per state using fast check
|
||||
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
|
||||
- For seed assets with all states missing: delete Asset and its AssetInfos
|
||||
- Optionally add/remove 'missing' tags based on fast-ok in this root
|
||||
- Optionally return surviving absolute paths
|
||||
"""
|
||||
prefixes = prefixes_for_root(root)
|
||||
if not prefixes:
|
||||
return set() if collect_existing_paths else None
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
async with await create_session() as sess:
|
||||
rows = (
|
||||
await sess.execute(
|
||||
sa.select(
|
||||
AssetCacheState.id,
|
||||
AssetCacheState.file_path,
|
||||
AssetCacheState.mtime_ns,
|
||||
AssetCacheState.needs_verify,
|
||||
AssetCacheState.asset_id,
|
||||
Asset.hash,
|
||||
Asset.size_bytes,
|
||||
)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(sa.or_(*conds))
|
||||
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||
)
|
||||
).all()
|
||||
|
||||
by_asset: dict[str, dict] = {}
|
||||
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
|
||||
acc = by_asset.get(aid)
|
||||
if acc is None:
|
||||
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
|
||||
by_asset[aid] = acc
|
||||
|
||||
fast_ok = False
|
||||
try:
|
||||
exists = True
|
||||
fast_ok = fast_asset_file_check(
|
||||
mtime_db=mtime_db,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(fp, follow_symlinks=True),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
exists = False
|
||||
except OSError:
|
||||
exists = False
|
||||
|
||||
acc["states"].append({
|
||||
"sid": sid,
|
||||
"fp": fp,
|
||||
"exists": exists,
|
||||
"fast_ok": fast_ok,
|
||||
"needs_verify": bool(needs_verify),
|
||||
})
|
||||
|
||||
to_set_verify: list[int] = []
|
||||
to_clear_verify: list[int] = []
|
||||
stale_state_ids: list[int] = []
|
||||
survivors: set[str] = set()
|
||||
|
||||
for aid, acc in by_asset.items():
|
||||
a_hash = acc["hash"]
|
||||
states = acc["states"]
|
||||
any_fast_ok = any(s["fast_ok"] for s in states)
|
||||
all_missing = all(not s["exists"] for s in states)
|
||||
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
continue
|
||||
if s["fast_ok"] and s["needs_verify"]:
|
||||
to_clear_verify.append(s["sid"])
|
||||
if not s["fast_ok"] and not s["needs_verify"]:
|
||||
to_set_verify.append(s["sid"])
|
||||
|
||||
if a_hash is None:
|
||||
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
|
||||
await sess.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == aid))
|
||||
asset = await sess.get(Asset, aid)
|
||||
if asset:
|
||||
await sess.delete(asset)
|
||||
else:
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
continue
|
||||
|
||||
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
stale_state_ids.append(s["sid"])
|
||||
if update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
||||
elif update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
||||
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
|
||||
if stale_state_ids:
|
||||
await sess.execute(sa.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
|
||||
if to_set_verify:
|
||||
await sess.execute(
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_set_verify))
|
||||
.values(needs_verify=True)
|
||||
)
|
||||
if to_clear_verify:
|
||||
await sess.execute(
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_clear_verify))
|
||||
.values(needs_verify=False)
|
||||
)
|
||||
await sess.commit()
|
||||
return survivors if collect_existing_paths else None
|
||||
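A brief sketch of how the scan entry points above are meant to be driven (run inside an existing event loop, after the database engine has been initialized; the roots list is only an example):

from app import assets_scanner

async def rescan_models_and_input():
    # Kick off (or reuse) one background scan task per root and return their initial statuses.
    response = await assets_scanner.schedule_scans(["models", "input"])
    for scan in response.scans:
        print(scan.scan_id, scan.root, scan.status)

    # Later, poll the in-memory progress without awaiting the scan tasks.
    for scan in assets_scanner.current_statuses().scans:
        print(scan.root, scan.processed, "/", scan.discovered, scan.status)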
0
app/database/__init__.py
Normal file
app/database/db.py
@ -1,112 +1,255 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from app.logger import log_startup_warning
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
_DB_AVAILABLE = False
|
||||
Session = None
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
ENGINE: Optional[AsyncEngine] = None
|
||||
SESSION: Optional[async_sessionmaker] = None
|
||||
|
||||
|
||||
try:
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
def _root_paths():
|
||||
"""Resolve alembic.ini and migrations script folder."""
|
||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
|
||||
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
|
||||
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
|
||||
return config_path, scripts_path
|
||||
|
||||
_DB_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
log_startup_warning(
|
||||
f"""
|
||||
------------------------------------------------------------------------
|
||||
Error importing dependencies: {e}
|
||||
{get_missing_requirements_message()}
|
||||
This error is happening because ComfyUI now uses a local sqlite database.
|
||||
------------------------------------------------------------------------
|
||||
""".strip()
|
||||
|
||||
def _absolutize_sqlite_url(db_url: str) -> str:
|
||||
"""Make SQLite database path absolute. No-op for non-SQLite URLs."""
|
||||
try:
|
||||
u = make_url(db_url)
|
||||
except Exception:
|
||||
return db_url
|
||||
|
||||
if not u.drivername.startswith("sqlite"):
|
||||
return db_url
|
||||
|
||||
db_path: str = u.database or ""
|
||||
if isinstance(db_path, str) and db_path.startswith("file:"):
|
||||
return str(u) # Do not touch SQLite URI databases like: "file:xxx?mode=memory&cache=shared"
|
||||
if not os.path.isabs(db_path):
|
||||
db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
|
||||
u = u.set(database=db_path)
|
||||
return str(u)
|
||||
|
||||
|
||||
def _normalize_sqlite_memory_url(db_url: str) -> tuple[str, bool]:
|
||||
"""
|
||||
If db_url points at an in-memory SQLite DB (":memory:" or file:... mode=memory),
|
||||
rewrite it to a *named* shared in-memory URI and ensure 'uri=true' is present.
|
||||
Returns: (normalized_url, is_memory)
|
||||
"""
|
||||
try:
|
||||
u = make_url(db_url)
|
||||
except Exception:
|
||||
return db_url, False
|
||||
if not u.drivername.startswith("sqlite"):
|
||||
return db_url, False
|
||||
|
||||
db = u.database or ""
|
||||
if db == ":memory:":
|
||||
u = u.set(database=f"file:comfyui_db_{os.getpid()}?mode=memory&cache=shared&uri=true")
|
||||
return str(u), True
|
||||
if isinstance(db, str) and db.startswith("file:") and "mode=memory" in db:
|
||||
if "uri=true" not in db:
|
||||
u = u.set(database=(db + ("&" if "?" in db else "?") + "uri=true"))
|
||||
return str(u), True
|
||||
return str(u), False
|
||||
|
||||
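# Usage sketch (illustrative; the shared in-memory name embeds the current PID):
url, is_mem = _normalize_sqlite_memory_url("sqlite+aiosqlite:///:memory:")
# url -> "sqlite+aiosqlite:///file:comfyui_db_<pid>?mode=memory&cache=shared&uri=true", is_mem -> True
url, is_mem = _normalize_sqlite_memory_url("sqlite+aiosqlite:///comfyui.db")
# A plain file database URL is returned unchanged, is_mem -> False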
|
||||
def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
|
||||
"""Return the on-disk path for a SQLite URL, else None."""
|
||||
try:
|
||||
u = make_url(sync_url)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if not u.drivername.startswith("sqlite"):
|
||||
return None
|
||||
db_path = u.database
|
||||
if isinstance(db_path, str) and db_path.startswith("file:"):
|
||||
return None # Not a real file if it is a URI like "file:...?"
|
||||
return db_path
|
||||
|
||||
|
||||
def _get_alembic_config(sync_url: str) -> Config:
|
||||
"""Prepare Alembic Config with script location and DB URL."""
|
||||
config_path, scripts_path = _root_paths()
|
||||
cfg = Config(config_path)
|
||||
cfg.set_main_option("script_location", scripts_path)
|
||||
cfg.set_main_option("sqlalchemy.url", sync_url)
|
||||
return cfg
|
||||
|
||||
|
||||
async def init_db_engine() -> None:
|
||||
"""Initialize async engine + sessionmaker and run migrations to head.
|
||||
|
||||
This must be called once on application startup before any DB usage.
|
||||
"""
|
||||
global ENGINE, SESSION
|
||||
|
||||
if ENGINE is not None:
|
||||
return
|
||||
|
||||
raw_url = args.database_url
|
||||
if not raw_url:
|
||||
raise RuntimeError("Database URL is not configured.")
|
||||
|
||||
db_url, is_mem = _normalize_sqlite_memory_url(raw_url)
|
||||
db_url = _absolutize_sqlite_url(db_url)
|
||||
|
||||
# Prepare async engine
|
||||
connect_args = {}
|
||||
if db_url.startswith("sqlite"):
|
||||
connect_args = {
|
||||
"check_same_thread": False,
|
||||
"timeout": 12,
|
||||
}
|
||||
if is_mem:
|
||||
connect_args["uri"] = True
|
||||
|
||||
ENGINE = create_async_engine(
|
||||
db_url,
|
||||
connect_args=connect_args,
|
||||
pool_pre_ping=True,
|
||||
future=True,
|
||||
)
|
||||
|
||||
# Enforce SQLite pragmas on the async engine
|
||||
if db_url.startswith("sqlite"):
|
||||
async with ENGINE.begin() as conn:
|
||||
if not is_mem:
|
||||
# WAL for concurrency and durability, Foreign Keys for referential integrity
|
||||
current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
|
||||
if str(current_mode).lower() != "wal":
|
||||
new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
|
||||
if str(new_mode).lower() != "wal":
|
||||
raise RuntimeError("Failed to set SQLite journal mode to WAL.")
|
||||
LOGGER.info("SQLite journal mode set to WAL.")
|
||||
|
||||
await conn.execute(text("PRAGMA foreign_keys = ON;"))
|
||||
await conn.execute(text("PRAGMA synchronous = NORMAL;"))
|
||||
|
||||
await _run_migrations(database_url=db_url, connect_args=connect_args)
|
||||
|
||||
SESSION = async_sessionmaker(
|
||||
bind=ENGINE,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
|
||||
def dependencies_available():
|
||||
"""
|
||||
Temporary function to check if the dependencies are available
|
||||
"""
|
||||
return _DB_AVAILABLE
|
||||
async def _run_migrations(database_url: str, connect_args: dict) -> None:
|
||||
if "postgresql+psycopg" not in database_url:
|
||||
"""SQLite: Convert an async SQLAlchemy URL to a sync URL for Alembic."""
|
||||
u = make_url(database_url)
|
||||
driver = u.drivername
|
||||
if not driver.startswith("sqlite+aiosqlite"):
|
||||
raise ValueError(f"Unsupported DB driver: {driver}")
|
||||
database_url, is_mem = _normalize_sqlite_memory_url(str(u.set(drivername="sqlite")))
|
||||
database_url = _absolutize_sqlite_url(database_url)
|
||||
|
||||
cfg = _get_alembic_config(database_url)
|
||||
engine = create_engine(database_url, future=True, connect_args=connect_args)
|
||||
with engine.connect() as conn:
|
||||
context = MigrationContext.configure(conn)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
def can_create_session():
|
||||
"""
|
||||
Temporary function to check if the database is available to create a session
|
||||
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
|
||||
"""
|
||||
return dependencies_available() and Session is not None
|
||||
|
||||
|
||||
def get_alembic_config():
|
||||
root_path = os.path.join(os.path.dirname(__file__), "../..")
|
||||
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
|
||||
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
|
||||
|
||||
config = Config(config_path)
|
||||
config.set_main_option("script_location", scripts_path)
|
||||
config.set_main_option("sqlalchemy.url", args.database_url)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_db_path():
|
||||
url = args.database_url
|
||||
if url.startswith("sqlite:///"):
|
||||
return url.split("///")[1]
|
||||
else:
|
||||
raise ValueError(f"Unsupported database URL '{url}'.")
|
||||
|
||||
|
||||
def init_db():
|
||||
db_url = args.database_url
|
||||
logging.debug(f"Database URL: {db_url}")
|
||||
db_path = get_db_path()
|
||||
db_exists = os.path.exists(db_path)
|
||||
|
||||
config = get_alembic_config()
|
||||
|
||||
# Check if we need to upgrade
|
||||
engine = create_engine(db_url)
|
||||
conn = engine.connect()
|
||||
|
||||
context = MigrationContext.configure(conn)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
script = ScriptDirectory.from_config(config)
|
||||
script = ScriptDirectory.from_config(cfg)
|
||||
target_rev = script.get_current_head()
|
||||
|
||||
if target_rev is None:
|
||||
logging.warning("No target revision found.")
|
||||
elif current_rev != target_rev:
|
||||
# Backup the database pre upgrade
|
||||
backup_path = db_path + ".bkp"
|
||||
if db_exists:
|
||||
shutil.copy(db_path, backup_path)
|
||||
else:
|
||||
backup_path = None
|
||||
LOGGER.warning("Alembic: no target revision found.")
|
||||
return
|
||||
|
||||
if current_rev == target_rev:
|
||||
LOGGER.debug("Alembic: database already at head %s", target_rev)
|
||||
return
|
||||
|
||||
LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev)
|
||||
|
||||
# Optional backup for SQLite file DBs
|
||||
backup_path = None
|
||||
sqlite_path = _get_sqlite_file_path(database_url)
|
||||
if sqlite_path and os.path.exists(sqlite_path):
|
||||
backup_path = sqlite_path + ".bkp"
|
||||
try:
|
||||
command.upgrade(config, target_rev)
|
||||
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
|
||||
except Exception as e:
|
||||
if backup_path:
|
||||
# Restore the database from backup if upgrade fails
|
||||
shutil.copy(backup_path, db_path)
|
||||
shutil.copy(sqlite_path, backup_path)
|
||||
except Exception as exc:
|
||||
LOGGER.warning("Failed to create SQLite backup before migration: %s", exc)
|
||||
|
||||
try:
|
||||
command.upgrade(cfg, target_rev)
|
||||
except Exception:
|
||||
if backup_path and os.path.exists(backup_path):
|
||||
LOGGER.exception("Error upgrading database, attempting restore from backup.")
|
||||
try:
|
||||
shutil.copy(backup_path, sqlite_path) # restore
|
||||
os.remove(backup_path)
|
||||
logging.exception("Error upgrading database: ")
|
||||
raise e
|
||||
|
||||
global Session
|
||||
Session = sessionmaker(bind=engine)
|
||||
except Exception as re:
|
||||
LOGGER.error("Failed to restore SQLite backup: %s", re)
|
||||
else:
|
||||
LOGGER.exception("Error upgrading database, backup is not available.")
|
||||
raise
|
||||
|
||||
|
||||
def create_session():
|
||||
return Session()
|
||||
def get_engine():
|
||||
"""Return the global async engine (initialized after init_db_engine())."""
|
||||
if ENGINE is None:
|
||||
raise RuntimeError("Engine is not initialized. Call init_db_engine() first.")
|
||||
return ENGINE
|
||||
|
||||
|
||||
def get_session_maker():
|
||||
"""Return the global async_sessionmaker (initialized after init_db_engine())."""
|
||||
if SESSION is None:
|
||||
raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.")
|
||||
return SESSION
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def session_scope():
|
||||
"""Async context manager for a unit of work:
|
||||
|
||||
async with session_scope() as sess:
|
||||
... use sess ...
|
||||
"""
|
||||
maker = get_session_maker()
|
||||
async with maker() as sess:
|
||||
try:
|
||||
yield sess
|
||||
await sess.commit()
|
||||
except Exception:
|
||||
await sess.rollback()
|
||||
raise
|
||||
|
||||
|
||||
async def create_session():
|
||||
"""Convenience helper to acquire a single AsyncSession instance.
|
||||
|
||||
Typical usage:
|
||||
async with (await create_session()) as sess:
|
||||
...
|
||||
"""
|
||||
maker = get_session_maker()
|
||||
return maker()
|
||||
|
||||
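Putting the pieces above together, a minimal startup sketch (assuming the module path app/database/db.py implied by the imports elsewhere in this diff, and that args.database_url is configured; this mirrors the docstrings of session_scope() and create_session()):

import asyncio
from app.database.db import create_session, init_db_engine, session_scope

async def main():
    await init_db_engine()                       # creates the async engine and runs Alembic migrations

    async with session_scope() as sess:          # commits on success, rolls back on error
        ...                                      # use sess here

    async with await create_session() as sess:   # manual unit of work: the caller commits explicitly
        ...                                      # use sess, then: await sess.commit()

asyncio.run(main())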
25
app/database/helpers/__init__.py
Normal file
@ -0,0 +1,25 @@
|
||||
from .bulk_ops import seed_from_paths_batch
|
||||
from .escape_like import escape_like_prefix
|
||||
from .fast_check import fast_asset_file_check
|
||||
from .filters import apply_metadata_filter, apply_tag_filters
|
||||
from .ownership import visible_owner_clause
|
||||
from .projection import is_scalar, project_kv
|
||||
from .tags import (
|
||||
add_missing_tag_for_asset_id,
|
||||
ensure_tags_exist,
|
||||
remove_missing_tag_for_asset_id,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"apply_tag_filters",
|
||||
"apply_metadata_filter",
|
||||
"escape_like_prefix",
|
||||
"fast_asset_file_check",
|
||||
"is_scalar",
|
||||
"project_kv",
|
||||
"ensure_tags_exist",
|
||||
"add_missing_tag_for_asset_id",
|
||||
"remove_missing_tag_for_asset_id",
|
||||
"seed_from_paths_batch",
|
||||
"visible_owner_clause",
|
||||
]
|
||||
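A usage sketch for one of the re-exported helpers, mirroring how escape_like_prefix is used in app/assets_scanner.py above (the base directory is made up; the helper is assumed to return an escaped prefix plus the escape character to pass to LIKE):

import os

import sqlalchemy as sa

from app.database.helpers import escape_like_prefix
from app.database.models import AssetCacheState

base = os.path.abspath("/path/to/ComfyUI/models/loras")   # made-up base directory
if not base.endswith(os.sep):
    base += os.sep
escaped, esc = escape_like_prefix(base)                    # (escaped_prefix, escape_char) for a safe LIKE prefix match
stmt = sa.select(AssetCacheState.id).where(
    AssetCacheState.file_path.like(escaped + "%", escape=esc)
)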
230
app/database/helpers/bulk_ops.py
Normal file
@ -0,0 +1,230 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql as d_pg
|
||||
from sqlalchemy.dialects import sqlite as d_sqlite
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoMeta, AssetInfoTag
|
||||
from ..timeutil import utcnow
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
|
||||
async def seed_from_paths_batch(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
specs: Sequence[dict],
|
||||
owner_id: str = "",
|
||||
) -> dict:
|
||||
"""Each spec is a dict with keys:
|
||||
- abs_path: str
|
||||
- size_bytes: int
|
||||
- mtime_ns: int
|
||||
- info_name: str
|
||||
- tags: list[str]
|
||||
- fname: Optional[str]
|
||||
"""
|
||||
if not specs:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
|
||||
|
||||
now = utcnow()
|
||||
dialect = session.bind.dialect.name
|
||||
if dialect not in ("sqlite", "postgresql"):
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
|
||||
asset_rows: list[dict] = []
|
||||
state_rows: list[dict] = []
|
||||
path_to_asset: dict[str, str] = {}
|
||||
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
|
||||
path_list: list[str] = []
|
||||
|
||||
for sp in specs:
|
||||
ap = os.path.abspath(sp["abs_path"])
|
||||
aid = str(uuid.uuid4())
|
||||
iid = str(uuid.uuid4())
|
||||
path_list.append(ap)
|
||||
path_to_asset[ap] = aid
|
||||
|
||||
asset_rows.append(
|
||||
{
|
||||
"id": aid,
|
||||
"hash": None,
|
||||
"size_bytes": sp["size_bytes"],
|
||||
"mime_type": None,
|
||||
"created_at": now,
|
||||
}
|
||||
)
|
||||
state_rows.append(
|
||||
{
|
||||
"asset_id": aid,
|
||||
"file_path": ap,
|
||||
"mtime_ns": sp["mtime_ns"],
|
||||
}
|
||||
)
|
||||
asset_to_info[aid] = {
|
||||
"id": iid,
|
||||
"owner_id": owner_id,
|
||||
"name": sp["info_name"],
|
||||
"asset_id": aid,
|
||||
"preview_id": None,
|
||||
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
"_tags": sp["tags"],
|
||||
"_filename": sp["fname"],
|
||||
}
|
||||
|
||||
# insert all seed Assets (hash=NULL)
|
||||
ins_asset = d_sqlite.insert(Asset) if dialect == "sqlite" else d_pg.insert(Asset)
|
||||
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
|
||||
await session.execute(ins_asset, chunk)
|
||||
|
||||
# try to claim AssetCacheState (file_path)
|
||||
winners_by_path: set[str] = set()
|
||||
if dialect == "sqlite":
|
||||
ins_state = (
|
||||
d_sqlite.insert(AssetCacheState)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
.returning(AssetCacheState.file_path)
|
||||
)
|
||||
else:
|
||||
ins_state = (
|
||||
d_pg.insert(AssetCacheState)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
.returning(AssetCacheState.file_path)
|
||||
)
|
||||
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
|
||||
winners_by_path.update((await session.execute(ins_state, chunk)).scalars().all())
|
||||
|
||||
all_paths_set = set(path_list)
|
||||
losers_by_path = all_paths_set - winners_by_path
|
||||
lost_assets = [path_to_asset[p] for p in losers_by_path]
|
||||
if lost_assets: # losers get their Asset removed
|
||||
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
|
||||
await session.execute(sa.delete(Asset).where(Asset.id.in_(id_chunk)))
|
||||
|
||||
if not winners_by_path:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
|
||||
|
||||
# insert AssetInfo only for winners
|
||||
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
|
||||
if dialect == "sqlite":
|
||||
ins_info = (
|
||||
d_sqlite.insert(AssetInfo)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
|
||||
.returning(AssetInfo.id)
|
||||
)
|
||||
else:
|
||||
ins_info = (
|
||||
d_pg.insert(AssetInfo)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
|
||||
.returning(AssetInfo.id)
|
||||
)
|
||||
|
||||
inserted_info_ids: set[str] = set()
|
||||
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
|
||||
inserted_info_ids.update((await session.execute(ins_info, chunk)).scalars().all())
|
||||
|
||||
# build and insert tag + meta rows for the AssetInfo
|
||||
tag_rows: list[dict] = []
|
||||
meta_rows: list[dict] = []
|
||||
if inserted_info_ids:
|
||||
for row in winner_info_rows:
|
||||
iid = row["id"]
|
||||
if iid not in inserted_info_ids:
|
||||
continue
|
||||
for t in row["_tags"]:
|
||||
tag_rows.append({
|
||||
"asset_info_id": iid,
|
||||
"tag_name": t,
|
||||
"origin": "automatic",
|
||||
"added_at": now,
|
||||
})
|
||||
if row["_filename"]:
|
||||
meta_rows.append(
|
||||
{
|
||||
"asset_info_id": iid,
|
||||
"key": "filename",
|
||||
"ordinal": 0,
|
||||
"val_str": row["_filename"],
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
}
|
||||
)
|
||||
|
||||
await bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
|
||||
return {
|
||||
"inserted_infos": len(inserted_info_ids),
|
||||
"won_states": len(winners_by_path),
|
||||
"lost_states": len(losers_by_path),
|
||||
}
|
||||
|
||||
|
||||
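A hedged sketch of how a filesystem scanner might call seed_from_paths_batch for newly discovered files; the spec keys mirror the docstring above, while the tag values and caller name are illustrative.

import os

async def seed_new_files(session, paths: list[str]) -> dict:
    # Build one spec per file, in the shape documented on seed_from_paths_batch.
    specs = []
    for p in paths:
        st = os.stat(p)
        specs.append({
            "abs_path": p,
            "size_bytes": st.st_size,
            "mtime_ns": getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)),
            "info_name": os.path.basename(p),
            "tags": ["models", "checkpoints"],  # illustrative tag set
            "fname": os.path.basename(p),
        })
    # Returns counters: inserted_infos / won_states / lost_states.
    return await seed_from_paths_batch(session, specs=specs, owner_id="")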
async def bulk_insert_tags_and_meta(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
tag_rows: list[dict],
|
||||
meta_rows: list[dict],
|
||||
max_bind_params: int,
|
||||
) -> None:
|
||||
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
|
||||
- tag_rows keys: asset_info_id, tag_name, origin, added_at
|
||||
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
|
||||
"""
|
||||
dialect = session.bind.dialect.name
|
||||
if tag_rows:
|
||||
if dialect == "sqlite":
|
||||
ins_links = (
|
||||
d_sqlite.insert(AssetInfoTag)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
ins_links = (
|
||||
d_pg.insert(AssetInfoTag)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
|
||||
await session.execute(ins_links, chunk)
|
||||
if meta_rows:
|
||||
if dialect == "sqlite":
|
||||
ins_meta = (
|
||||
d_sqlite.insert(AssetInfoMeta)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
|
||||
)
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
ins_meta = (
|
||||
d_pg.insert(AssetInfoMeta)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
|
||||
await session.execute(ins_meta, chunk)
|
||||
|
||||
|
||||
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
|
||||
if not rows:
|
||||
return []
|
||||
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
|
||||
for i in range(0, len(rows), rows_per_stmt):
|
||||
yield rows[i:i + rows_per_stmt]
|
||||
|
||||
|
||||
def _iter_chunks(seq, n: int):
|
||||
for i in range(0, len(seq), n):
|
||||
yield seq[i:i + n]
|
||||
|
||||
|
||||
def _rows_per_stmt(cols: int) -> int:
|
||||
return max(1, MAX_BIND_PARAMS // max(1, cols))
|
||||
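For clarity, the bind-parameter budget enforced by the chunking helpers above works out as follows (numbers taken from the constants in this file; the 9-column case is the AssetInfo insert).

# With MAX_BIND_PARAMS = 800 and 9 columns per AssetInfo row, each executemany
# batch carries at most 800 // 9 == 88 rows, i.e. 792 bind parameters.
rows_per_stmt = max(1, 800 // max(1, 9))
assert rows_per_stmt == 88
assert rows_per_stmt * 9 <= 800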
app/database/helpers/escape_like.py (new file, 7 lines)
@@ -0,0 +1,7 @@
|
||||
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
|
||||
"""Escapes %, _ and the escape char itself in a LIKE prefix.
|
||||
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
|
||||
"""
|
||||
s = s.replace(escape, escape + escape) # escape the escape char first
|
||||
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
|
||||
return s, escape
|
||||
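A small sketch of the intended call pattern; the column argument and helper name are illustrative.

def prefix_filter(column, user_prefix: str):
    # Escape LIKE wildcards in the user-supplied prefix, then append '%' ourselves.
    escaped, esc = escape_like_prefix(user_prefix)
    return column.like(escaped + "%", escape=esc)

# Example: escape_like_prefix("50%_off") -> ("50!%!_off", "!")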
app/database/helpers/fast_check.py (new file, 19 lines)
@@ -0,0 +1,19 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def fast_asset_file_check(
|
||||
*,
|
||||
mtime_db: Optional[int],
|
||||
size_db: Optional[int],
|
||||
stat_result: os.stat_result,
|
||||
) -> bool:
|
||||
if mtime_db is None:
|
||||
return False
|
||||
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
|
||||
if int(mtime_db) != int(actual_mtime_ns):
|
||||
return False
|
||||
sz = int(size_db or 0)
|
||||
if sz > 0:
|
||||
return int(stat_result.st_size) == sz
|
||||
return True
|
||||
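A sketch of the cheap freshness check in context; the caller is hypothetical, and mtime_db / size_db stand for the cached values stored in the database.

import os

def is_cache_entry_fresh(file_path: str, mtime_db, size_db) -> bool:
    # stat() once and compare against the cached mtime/size; no hashing involved.
    try:
        st = os.stat(file_path)
    except FileNotFoundError:
        return False
    return fast_asset_file_check(mtime_db=mtime_db, size_db=size_db, stat_result=st)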
app/database/helpers/filters.py (new file, 87 lines)
@@ -0,0 +1,87 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import exists
|
||||
|
||||
from ..._assets_helpers import normalize_tags
|
||||
from ..models import AssetInfo, AssetInfoMeta, AssetInfoTag
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Optional[Sequence[str]],
|
||||
exclude_tags: Optional[Sequence[str]],
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: Optional[dict],
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_info_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetInfoMeta.val_json.is_(None),
|
||||
AssetInfoMeta.val_str.is_(None),
|
||||
AssetInfoMeta.val_num.is_(None),
|
||||
AssetInfoMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float)):
|
||||
from decimal import Decimal
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
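A hedged sketch of composing both filters onto a listing query; the tag names and metadata values are illustrative.

import sqlalchemy as sa

def build_listing_query(include=None, exclude=None, metadata=None) -> sa.sql.Select:
    stmt = sa.select(AssetInfo)
    # Every include tag must be present; no exclude tag may be present.
    stmt = apply_tag_filters(stmt, include or ["checkpoints"], exclude or ["missing"])
    # Metadata filter matches typed projection rows; a list value means "any of these".
    return apply_metadata_filter(stmt, metadata or {"filename": ["model_a.safetensors", "model_b.safetensors"]})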
app/database/helpers/ownership.py (new file, 12 lines)
@@ -0,0 +1,12 @@
|
||||
import sqlalchemy as sa
|
||||
|
||||
from ..models import AssetInfo
|
||||
|
||||
|
||||
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
|
||||
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetInfo.owner_id == ""
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
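In practice the clause is just dropped into a WHERE; a one-line sketch:

import sqlalchemy as sa

def visible_infos_for(owner_id: str) -> sa.sql.Select:
    # Owner-less rows ("") are public; a named owner additionally sees their own rows.
    return sa.select(AssetInfo).where(visible_owner_clause(owner_id))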
app/database/helpers/projection.py (new file, 64 lines)
@@ -0,0 +1,64 @@
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
def is_scalar(v):
|
||||
if v is None:
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def project_kv(key: str, value):
|
||||
"""
|
||||
Turn a metadata key/value into typed projection rows.
|
||||
Returns list[dict] with keys:
|
||||
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
def _null_row(ordinal: int) -> dict:
|
||||
return {
|
||||
"key": key, "ordinal": ordinal,
|
||||
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||
}
|
||||
|
||||
if value is None:
|
||||
rows.append(_null_row(0))
|
||||
return rows
|
||||
|
||||
if is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append(_null_row(i))
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
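Illustrative expected outputs of project_kv; the keys and values are chosen for the example, and only the populated typed column is shown.

# Scalar -> one row in the matching typed column.
# project_kv("rating", 4.5)   -> [{"key": "rating", "ordinal": 0, "val_num": Decimal("4.5")}]
# project_kv("nsfw", False)   -> [{"key": "nsfw", "ordinal": 0, "val_bool": False}]
# List of scalars -> one row per element; ordinal preserves position.
# project_kv("triggers", ["a", "b"]) -> [
#     {"key": "triggers", "ordinal": 0, "val_str": "a"},
#     {"key": "triggers", "ordinal": 1, "val_str": "b"},
# ]
# Dict (or any other structure) -> a single val_json row.
# project_kv("extra", {"x": 1}) -> [{"key": "extra", "ordinal": 0, "val_json": {"x": 1}}]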
app/database/helpers/tags.py (new file, 90 lines)
@@ -0,0 +1,90 @@
|
||||
from typing import Iterable
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql as d_pg
|
||||
from sqlalchemy.dialects import sqlite as d_sqlite
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..._assets_helpers import normalize_tags
|
||||
from ..models import AssetInfo, AssetInfoTag, Tag
|
||||
from ..timeutil import utcnow
|
||||
|
||||
|
||||
async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
dialect = session.bind.dialect.name
|
||||
if dialect == "sqlite":
|
||||
ins = (
|
||||
d_sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
ins = (
|
||||
d_pg.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
await session.execute(ins)
|
||||
|
||||
|
||||
async def add_missing_tag_for_asset_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_id: str,
|
||||
origin: str = "automatic",
|
||||
) -> None:
|
||||
select_rows = (
|
||||
sa.select(
|
||||
AssetInfo.id.label("asset_info_id"),
|
||||
sa.literal("missing").label("tag_name"),
|
||||
sa.literal(origin).label("origin"),
|
||||
sa.literal(utcnow()).label("added_at"),
|
||||
)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.where(
|
||||
sa.not_(
|
||||
sa.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
|
||||
)
|
||||
)
|
||||
)
|
||||
dialect = session.bind.dialect.name
|
||||
if dialect == "sqlite":
|
||||
ins = (
|
||||
d_sqlite.insert(AssetInfoTag)
|
||||
.from_select(
|
||||
["asset_info_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
ins = (
|
||||
d_pg.insert(AssetInfoTag)
|
||||
.from_select(
|
||||
["asset_info_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
await session.execute(ins)
|
||||
|
||||
|
||||
async def remove_missing_tag_for_asset_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
await session.execute(
|
||||
sa.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
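A sketch of the maintenance flow these helpers support; the caller names are illustrative, and the tag_type value is a guess (the helper defaults to "user").

async def mark_asset_missing(session, asset_id: str) -> None:
    # Make sure the "missing" tag row exists before linking it (tag_type is assumed here).
    await ensure_tags_exist(session, ["missing"], tag_type="user")
    await add_missing_tag_for_asset_id(session, asset_id=asset_id, origin="automatic")

async def mark_asset_found(session, asset_id: str) -> None:
    await remove_missing_tag_for_asset_id(session, asset_id=asset_id)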
@@ -1,59 +1,250 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
Text,
|
||||
JSON,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
CheckConstraint,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, foreign, mapped_column, relationship
|
||||
|
||||
Base = declarative_base()
|
||||
from .timeutil import utcnow
|
||||
|
||||
JSONB_V = JSON(none_as_null=True).with_variant(JSONB(none_as_null=True), 'postgresql')
|
||||
|
||||
|
||||
def to_dict(obj):
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
||||
fields = obj.__table__.columns.keys()
|
||||
return {
|
||||
field: (val.to_dict() if hasattr(val, "to_dict") else val)
|
||||
for field in fields
|
||||
if (val := getattr(obj, field))
|
||||
}
|
||||
out: dict[str, Any] = {}
|
||||
for field in fields:
|
||||
val = getattr(obj, field)
|
||||
if val is None and not include_none:
|
||||
continue
|
||||
if isinstance(val, datetime):
|
||||
out[field] = val.isoformat()
|
||||
else:
|
||||
out[field] = val
|
||||
return out
|
||||
|
||||
|
||||
class Model(Base):
|
||||
"""
|
||||
sqlalchemy model representing a model file in the system.
|
||||
class Asset(Base):
|
||||
__tablename__ = "assets"
|
||||
|
||||
This class defines the database schema for storing information about model files,
|
||||
including their type, path, hash, and when they were added to the system.
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
hash: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
|
||||
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
mime_type: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
|
||||
Attributes:
|
||||
type (Text): The type of the model, this is the name of the folder in the models folder (primary key)
|
||||
path (Text): The file path of the model relative to the type folder (primary key)
|
||||
file_name (Text): The name of the model file
|
||||
file_size (Integer): The size of the model file in bytes
|
||||
hash (Text): A hash of the model file
|
||||
hash_algorithm (Text): The algorithm used to generate the hash
|
||||
source_url (Text): The URL of the model file
|
||||
date_added (DateTime): Timestamp of when the model was added to the system
|
||||
"""
|
||||
infos: Mapped[list["AssetInfo"]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
|
||||
foreign_keys=lambda: [AssetInfo.asset_id],
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
__tablename__ = "model"
|
||||
preview_of: Mapped[list["AssetInfo"]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="preview_asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
|
||||
foreign_keys=lambda: [AssetInfo.preview_id],
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
type = Column(Text, primary_key=True)
|
||||
path = Column(Text, primary_key=True)
|
||||
file_name = Column(Text)
|
||||
file_size = Column(Integer)
|
||||
hash = Column(Text)
|
||||
hash_algorithm = Column(Text)
|
||||
source_url = Column(Text)
|
||||
date_added = Column(DateTime, server_default=func.now())
|
||||
cache_states: Mapped[list["AssetCacheState"]] = relationship(
|
||||
back_populates="asset",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Convert the model instance to a dictionary representation.
|
||||
__table_args__ = (
|
||||
Index("ix_assets_mime_type", "mime_type"),
|
||||
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the attributes of the model
|
||||
"""
|
||||
dict = to_dict(self)
|
||||
return dict
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
|
||||
|
||||
|
||||
class AssetCacheState(Base):
|
||||
__tablename__ = "asset_cache_state"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
|
||||
file_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
mtime_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)
|
||||
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
asset: Mapped["Asset"] = relationship(back_populates="cache_states")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_cache_state_file_path", "file_path"),
|
||||
Index("ix_asset_cache_state_asset_id", "asset_id"),
|
||||
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
|
||||
|
||||
|
||||
class AssetInfo(Base):
|
||||
__tablename__ = "assets_info"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
|
||||
preview_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
|
||||
user_metadata: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON(none_as_null=True))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
|
||||
asset: Mapped[Asset] = relationship(
|
||||
"Asset",
|
||||
back_populates="infos",
|
||||
foreign_keys=[asset_id],
|
||||
lazy="selectin",
|
||||
)
|
||||
preview_asset: Mapped[Optional[Asset]] = relationship(
|
||||
"Asset",
|
||||
back_populates="preview_of",
|
||||
foreign_keys=[preview_id],
|
||||
)
|
||||
|
||||
metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
tag_links: Mapped[list["AssetInfoTag"]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
overlaps="tags,asset_infos",
|
||||
)
|
||||
|
||||
tags: Mapped[list["Tag"]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="asset_infos",
|
||||
lazy="selectin",
|
||||
viewonly=True,
|
||||
overlaps="tag_links,asset_info_links,asset_infos,tag",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
Index("ix_assets_info_owner_name", "owner_id", "name"),
|
||||
Index("ix_assets_info_owner_id", "owner_id"),
|
||||
Index("ix_assets_info_asset_id", "asset_id"),
|
||||
Index("ix_assets_info_name", "name"),
|
||||
Index("ix_assets_info_created_at", "created_at"),
|
||||
Index("ix_assets_info_last_access_time", "last_access_time"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
data = to_dict(self, include_none=include_none)
|
||||
data["tags"] = [t.name for t in self.tags]
|
||||
return data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
|
||||
|
||||
|
||||
class AssetInfoMeta(Base):
|
||||
__tablename__ = "asset_info_meta"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
key: Mapped[str] = mapped_column(String(256), primary_key=True)
|
||||
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
|
||||
|
||||
val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
|
||||
val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True)
|
||||
val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
|
||||
val_json: Mapped[Optional[Any]] = mapped_column(JSONB_V, nullable=True)
|
||||
|
||||
asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_meta_key", "key"),
|
||||
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
|
||||
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
|
||||
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
|
||||
)
|
||||
|
||||
|
||||
class AssetInfoTag(Base):
|
||||
__tablename__ = "asset_info_tags"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
tag_name: Mapped[str] = mapped_column(
|
||||
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
|
||||
)
|
||||
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
|
||||
added_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
|
||||
asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links")
|
||||
tag: Mapped["Tag"] = relationship(back_populates="asset_info_links")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_tags_tag_name", "tag_name"),
|
||||
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
|
||||
)
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(512), primary_key=True)
|
||||
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
|
||||
|
||||
asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(
|
||||
back_populates="tag",
|
||||
overlaps="asset_infos,tags",
|
||||
)
|
||||
asset_infos: Mapped[list["AssetInfo"]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="tags",
|
||||
viewonly=True,
|
||||
overlaps="asset_info_links,tag_links,tags,asset_info",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_tags_tag_type", "tag_type"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Tag {self.name}>"
|
||||
|
||||
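For orientation, a minimal sketch of wiring the ORM objects above together in one session; the hash, path and name are illustrative, and in the real flow ingest_fs_asset in services/content.py performs this work.

async def create_minimal_asset(session) -> AssetInfo:
    asset = Asset(hash="blake3:abc123", size_bytes=1024, mime_type="application/octet-stream")
    info = AssetInfo(name="example.safetensors", owner_id="", asset=asset)
    state = AssetCacheState(asset=asset, file_path="/models/example.safetensors", mtime_ns=0)
    session.add_all([asset, info, state])
    await session.flush()  # column defaults assign the UUID primary keys and timestamps
    return info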
app/database/services/__init__.py (new file, 57 lines)
@@ -0,0 +1,57 @@
|
||||
from .content import (
|
||||
check_fs_asset_exists_quick,
|
||||
compute_hash_and_dedup_for_cache_state,
|
||||
ingest_fs_asset,
|
||||
list_cache_states_with_asset_under_prefixes,
|
||||
list_unhashed_candidates_under_prefixes,
|
||||
list_verify_candidates_under_prefixes,
|
||||
redirect_all_references_then_delete_asset,
|
||||
touch_asset_infos_by_fs_path,
|
||||
)
|
||||
from .info import (
|
||||
add_tags_to_asset_info,
|
||||
create_asset_info_for_existing_asset,
|
||||
delete_asset_info_by_id,
|
||||
fetch_asset_info_and_asset,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
get_asset_tags,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
remove_tags_from_asset_info,
|
||||
replace_asset_info_metadata_projection,
|
||||
set_asset_info_preview,
|
||||
set_asset_info_tags,
|
||||
touch_asset_info_by_id,
|
||||
update_asset_info_full,
|
||||
)
|
||||
from .queries import (
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
get_cache_state_by_asset_id,
|
||||
list_cache_states_by_asset_id,
|
||||
pick_best_live_path,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# queries
|
||||
"asset_exists_by_hash", "get_asset_by_hash", "get_asset_info_by_id", "asset_info_exists_for_asset_id",
|
||||
"get_cache_state_by_asset_id",
|
||||
"list_cache_states_by_asset_id",
|
||||
"pick_best_live_path",
|
||||
# info
|
||||
"list_asset_infos_page", "create_asset_info_for_existing_asset", "set_asset_info_tags",
|
||||
"update_asset_info_full", "replace_asset_info_metadata_projection",
|
||||
"touch_asset_info_by_id", "delete_asset_info_by_id",
|
||||
"add_tags_to_asset_info", "remove_tags_from_asset_info",
|
||||
"get_asset_tags", "list_tags_with_usage", "set_asset_info_preview",
|
||||
"fetch_asset_info_and_asset", "fetch_asset_info_asset_and_tags",
|
||||
# content
|
||||
"check_fs_asset_exists_quick",
|
||||
"redirect_all_references_then_delete_asset",
|
||||
"compute_hash_and_dedup_for_cache_state",
|
||||
"list_unhashed_candidates_under_prefixes", "list_verify_candidates_under_prefixes",
|
||||
"ingest_fs_asset", "touch_asset_infos_by_fs_path",
|
||||
"list_cache_states_with_asset_under_prefixes",
|
||||
]
|
||||
app/database/services/content.py (new file, 721 lines)
@@ -0,0 +1,721 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects import postgresql as d_pg
|
||||
from sqlalchemy.dialects import sqlite as d_sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import noload
|
||||
|
||||
from ..._assets_helpers import compute_relative_filename
|
||||
from ...storage import hashing as hashing_mod
|
||||
from ..helpers import (
|
||||
ensure_tags_exist,
|
||||
escape_like_prefix,
|
||||
remove_missing_tag_for_asset_id,
|
||||
)
|
||||
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag
|
||||
from ..timeutil import utcnow
|
||||
from .info import replace_asset_info_metadata_projection
|
||||
from .queries import list_cache_states_by_asset_id, pick_best_live_path
|
||||
|
||||
|
||||
async def check_fs_asset_exists_quick(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
file_path: str,
|
||||
size_bytes: Optional[int] = None,
|
||||
mtime_ns: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""Returns True if we already track this absolute path with a HASHED asset and the cached mtime/size match."""
|
||||
locator = os.path.abspath(file_path)
|
||||
|
||||
stmt = (
|
||||
sa.select(sa.literal(True))
|
||||
.select_from(AssetCacheState)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(
|
||||
AssetCacheState.file_path == locator,
|
||||
Asset.hash.isnot(None),
|
||||
AssetCacheState.needs_verify.is_(False),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
conds = []
|
||||
if mtime_ns is not None:
|
||||
conds.append(AssetCacheState.mtime_ns == int(mtime_ns))
|
||||
if size_bytes is not None:
|
||||
conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
|
||||
if conds:
|
||||
stmt = stmt.where(*conds)
|
||||
return (await session.execute(stmt)).first() is not None
|
||||
|
||||
|
||||
async def redirect_all_references_then_delete_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
duplicate_asset_id: str,
|
||||
canonical_asset_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Safely migrate all references from duplicate_asset_id to canonical_asset_id.
|
||||
|
||||
- If an AssetInfo for (owner_id, name) already exists on the canonical asset,
|
||||
merge tags, metadata, times, and preview, then delete the duplicate AssetInfo.
|
||||
- Otherwise, simply repoint the AssetInfo.asset_id.
|
||||
- Always retarget AssetCacheState rows.
|
||||
- Finally delete the duplicate Asset row.
|
||||
"""
|
||||
if duplicate_asset_id == canonical_asset_id:
|
||||
return
|
||||
|
||||
# 1) Migrate AssetInfo rows one-by-one to avoid UNIQUE conflicts.
|
||||
dup_infos = (
|
||||
await session.execute(
|
||||
select(AssetInfo).options(noload(AssetInfo.tags)).where(AssetInfo.asset_id == duplicate_asset_id)
|
||||
)
|
||||
).unique().scalars().all()
|
||||
|
||||
for info in dup_infos:
|
||||
# Try to find an existing collision on canonical
|
||||
existing = (
|
||||
await session.execute(
|
||||
select(AssetInfo)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.where(
|
||||
AssetInfo.asset_id == canonical_asset_id,
|
||||
AssetInfo.owner_id == info.owner_id,
|
||||
AssetInfo.name == info.name,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalars().first()
|
||||
|
||||
if existing:
|
||||
merged_meta = dict(existing.user_metadata or {})
|
||||
other_meta = info.user_metadata or {}
|
||||
for k, v in other_meta.items():
|
||||
if k not in merged_meta:
|
||||
merged_meta[k] = v
|
||||
if merged_meta != (existing.user_metadata or {}):
|
||||
await replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=existing.id,
|
||||
user_metadata=merged_meta,
|
||||
)
|
||||
|
||||
existing_tags = {
|
||||
t for (t,) in (
|
||||
await session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == existing.id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
from_tags = {
|
||||
t for (t,) in (
|
||||
await session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == info.id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
to_add = sorted(from_tags - existing_tags)
|
||||
if to_add:
|
||||
await ensure_tags_exist(session, to_add, tag_type="user")
|
||||
now = utcnow()
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=existing.id, tag_name=t, origin="automatic", added_at=now)
|
||||
for t in to_add
|
||||
])
|
||||
await session.flush()
|
||||
|
||||
if existing.preview_id is None and info.preview_id is not None:
|
||||
existing.preview_id = info.preview_id
|
||||
if info.last_access_time and (
|
||||
existing.last_access_time is None or info.last_access_time > existing.last_access_time
|
||||
):
|
||||
existing.last_access_time = info.last_access_time
|
||||
existing.updated_at = utcnow()
|
||||
await session.flush()
|
||||
|
||||
# Delete the duplicate AssetInfo (cascades will clean its tags/meta)
|
||||
await session.delete(info)
|
||||
await session.flush()
|
||||
else:
|
||||
# Simple retarget
|
||||
info.asset_id = canonical_asset_id
|
||||
info.updated_at = utcnow()
|
||||
await session.flush()
|
||||
|
||||
# 2) Repoint cache states and previews
|
||||
await session.execute(
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == duplicate_asset_id)
|
||||
.values(asset_id=canonical_asset_id)
|
||||
)
|
||||
await session.execute(
|
||||
sa.update(AssetInfo)
|
||||
.where(AssetInfo.preview_id == duplicate_asset_id)
|
||||
.values(preview_id=canonical_asset_id)
|
||||
)
|
||||
|
||||
# 3) Remove duplicate Asset
|
||||
dup = await session.get(Asset, duplicate_asset_id)
|
||||
if dup:
|
||||
await session.delete(dup)
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def compute_hash_and_dedup_for_cache_state(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
state_id: int,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Compute hash for the given cache state, deduplicate, and settle verify cases.
|
||||
|
||||
Returns the asset_id that this state ends up pointing to, or None if file disappeared.
|
||||
"""
|
||||
state = await session.get(AssetCacheState, state_id)
|
||||
if not state:
|
||||
return None
|
||||
|
||||
path = state.file_path
|
||||
try:
|
||||
if not os.path.isfile(path):
|
||||
# File vanished: drop the state. If the Asset has hash=NULL and has no other states, drop the Asset too.
|
||||
asset = await session.get(Asset, state.asset_id)
|
||||
await session.delete(state)
|
||||
await session.flush()
|
||||
|
||||
if asset and asset.hash is None:
|
||||
remaining = (
|
||||
await session.execute(
|
||||
sa.select(sa.func.count())
|
||||
.select_from(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset.id)
|
||||
)
|
||||
).scalar_one()
|
||||
if int(remaining or 0) == 0:
|
||||
await session.delete(asset)
|
||||
await session.flush()
|
||||
else:
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=asset.id)
|
||||
return None
|
||||
|
||||
digest = await hashing_mod.blake3_hash(path)
|
||||
new_hash = f"blake3:{digest}"
|
||||
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
new_size = int(st.st_size)
|
||||
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
# Current asset of this state
|
||||
this_asset = await session.get(Asset, state.asset_id)
|
||||
|
||||
# If the state got orphaned somehow (race), just reattach appropriately.
|
||||
if not this_asset:
|
||||
canonical = (
|
||||
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||
).scalars().first()
|
||||
if canonical:
|
||||
state.asset_id = canonical.id
|
||||
else:
|
||||
now = utcnow()
|
||||
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
|
||||
session.add(new_asset)
|
||||
await session.flush()
|
||||
state.asset_id = new_asset.id
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=state.asset_id)
|
||||
await session.flush()
|
||||
return state.asset_id
|
||||
|
||||
# 1) Seed asset case (hash is NULL): claim or merge into canonical
|
||||
if this_asset.hash is None:
|
||||
canonical = (
|
||||
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||
).scalars().first()
|
||||
|
||||
if canonical and canonical.id != this_asset.id:
|
||||
# Merge seed asset into canonical (safe, collision-aware)
|
||||
await redirect_all_references_then_delete_asset(
|
||||
session,
|
||||
duplicate_asset_id=this_asset.id,
|
||||
canonical_asset_id=canonical.id,
|
||||
)
|
||||
state = await session.get(AssetCacheState, state_id)
|
||||
if state:
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id)
|
||||
await session.flush()
|
||||
return canonical.id
|
||||
|
||||
# No canonical: try to claim the hash; handle races with a SAVEPOINT
|
||||
try:
|
||||
async with session.begin_nested():
|
||||
this_asset.hash = new_hash
|
||||
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
|
||||
this_asset.size_bytes = new_size
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
# Someone else claimed it concurrently; fetch canonical and merge
|
||||
canonical = (
|
||||
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||
).scalars().first()
|
||||
if canonical and canonical.id != this_asset.id:
|
||||
await redirect_all_references_then_delete_asset(
|
||||
session,
|
||||
duplicate_asset_id=this_asset.id,
|
||||
canonical_asset_id=canonical.id,
|
||||
)
|
||||
state = await session.get(AssetCacheState, state_id)
|
||||
if state:
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=canonical.id)
|
||||
await session.flush()
|
||||
return canonical.id
|
||||
# If we got here, the integrity error was not about hash uniqueness
|
||||
raise
|
||||
|
||||
# Claimed successfully
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id)
|
||||
await session.flush()
|
||||
return this_asset.id
|
||||
|
||||
# 2) Verify case for hashed assets
|
||||
if this_asset.hash == new_hash:
|
||||
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
|
||||
this_asset.size_bytes = new_size
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=this_asset.id)
|
||||
await session.flush()
|
||||
return this_asset.id
|
||||
|
||||
# Content changed on this path only: retarget THIS state, do not move AssetInfo rows
|
||||
canonical = (
|
||||
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||
).scalars().first()
|
||||
if canonical:
|
||||
target_id = canonical.id
|
||||
else:
|
||||
now = utcnow()
|
||||
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
|
||||
session.add(new_asset)
|
||||
await session.flush()
|
||||
target_id = new_asset.id
|
||||
|
||||
state.asset_id = target_id
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
with contextlib.suppress(Exception):
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=target_id)
|
||||
await _recompute_and_apply_filename_for_asset(session, asset_id=target_id)
|
||||
await session.flush()
|
||||
return target_id
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
|
||||
async def list_unhashed_candidates_under_prefixes(session: AsyncSession, *, prefixes: list[str]) -> list[int]:
|
||||
if not prefixes:
|
||||
return []
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
path_filter = sa.or_(*conds) if len(conds) > 1 else conds[0]
|
||||
if session.bind.dialect.name == "postgresql":
|
||||
stmt = (
|
||||
sa.select(AssetCacheState.id)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(Asset.hash.is_(None), path_filter)
|
||||
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||
.distinct(AssetCacheState.asset_id)
|
||||
)
|
||||
else:
|
||||
first_id = sa.func.min(AssetCacheState.id).label("first_id")
|
||||
stmt = (
|
||||
sa.select(first_id)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(Asset.hash.is_(None), path_filter)
|
||||
.group_by(AssetCacheState.asset_id)
|
||||
.order_by(first_id.asc())
|
||||
)
|
||||
return [int(x) for x in (await session.execute(stmt)).scalars().all()]
|
||||
|
||||
|
||||
async def list_verify_candidates_under_prefixes(
|
||||
session: AsyncSession, *, prefixes: Sequence[str]
|
||||
) -> Union[list[int], Sequence[int]]:
|
||||
if not prefixes:
|
||||
return []
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
return (
|
||||
await session.execute(
|
||||
sa.select(AssetCacheState.id)
|
||||
.where(AssetCacheState.needs_verify.is_(True))
|
||||
.where(sa.or_(*conds))
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
async def ingest_fs_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_hash: str,
|
||||
abs_path: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: Optional[str] = None,
|
||||
info_name: Optional[str] = None,
|
||||
owner_id: str = "",
|
||||
preview_id: Optional[str] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Idempotently upsert:
|
||||
- Asset by content hash (create if missing)
|
||||
- AssetCacheState(file_path) pointing to asset_id
|
||||
- Optionally AssetInfo + tag links and metadata projection
|
||||
Returns flags and ids.
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
now = utcnow()
|
||||
dialect = session.bind.dialect.name
|
||||
|
||||
if preview_id:
|
||||
if not await session.get(Asset, preview_id):
|
||||
preview_id = None
|
||||
|
||||
out: dict[str, Any] = {
|
||||
"asset_created": False,
|
||||
"asset_updated": False,
|
||||
"state_created": False,
|
||||
"state_updated": False,
|
||||
"asset_info_id": None,
|
||||
}
|
||||
|
||||
# 1) Asset by hash
|
||||
asset = (
|
||||
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
vals = {
|
||||
"hash": asset_hash,
|
||||
"size_bytes": int(size_bytes),
|
||||
"mime_type": mime_type,
|
||||
"created_at": now,
|
||||
}
|
||||
if dialect == "sqlite":
|
||||
res = await session.execute(
|
||||
d_sqlite.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["asset_created"] = True
|
||||
asset = (
|
||||
await session.execute(
|
||||
select(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
elif dialect == "postgresql":
|
||||
res = await session.execute(
|
||||
d_pg.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[Asset.hash],
|
||||
index_where=Asset.__table__.c.hash.isnot(None),
|
||||
)
|
||||
.returning(Asset.id)
|
||||
)
|
||||
inserted_id = res.scalar_one_or_none()
|
||||
if inserted_id:
|
||||
out["asset_created"] = True
|
||||
asset = await session.get(Asset, inserted_id)
|
||||
else:
|
||||
asset = (
|
||||
await session.execute(
|
||||
select(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
if not asset:
|
||||
raise RuntimeError("Asset row not found after upsert.")
|
||||
else:
|
||||
changed = False
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and asset.mime_type != mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
out["asset_updated"] = True
|
||||
|
||||
# 2) AssetCacheState upsert by file_path (unique)
|
||||
vals = {
|
||||
"asset_id": asset.id,
|
||||
"file_path": locator,
|
||||
"mtime_ns": int(mtime_ns),
|
||||
}
|
||||
if dialect == "sqlite":
|
||||
ins = (
|
||||
d_sqlite.insert(AssetCacheState)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
ins = (
|
||||
d_pg.insert(AssetCacheState)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||
|
||||
res = await session.execute(ins)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["state_created"] = True
|
||||
else:
|
||||
upd = (
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.file_path == locator)
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetCacheState.asset_id != asset.id,
|
||||
AssetCacheState.mtime_ns.is_(None),
|
||||
AssetCacheState.mtime_ns != int(mtime_ns),
|
||||
)
|
||||
)
|
||||
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
|
||||
)
|
||||
res2 = await session.execute(upd)
|
||||
if int(res2.rowcount or 0) > 0:
|
||||
out["state_updated"] = True
|
||||
|
||||
# 3) Optional AssetInfo + tags + metadata
|
||||
if info_name:
|
||||
try:
|
||||
async with session.begin_nested():
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=info_name,
|
||||
asset_id=asset.id,
|
||||
preview_id=preview_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(info)
|
||||
await session.flush()
|
||||
out["asset_info_id"] = info.id
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
existing_info = (
|
||||
await session.execute(
|
||||
select(AssetInfo)
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == info_name,
|
||||
(AssetInfo.owner_id == owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalar_one_or_none()
|
||||
if not existing_info:
|
||||
raise RuntimeError("Failed to update or insert AssetInfo.")
|
||||
|
||||
if preview_id and existing_info.preview_id != preview_id:
|
||||
existing_info.preview_id = preview_id
|
||||
|
||||
existing_info.updated_at = now
|
||||
if existing_info.last_access_time < now:
|
||||
existing_info.last_access_time = now
|
||||
await session.flush()
|
||||
out["asset_info_id"] = existing_info.id
|
||||
|
||||
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
if norm and out["asset_info_id"] is not None:
|
||||
if not require_existing_tags:
|
||||
await ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
existing_tag_names = set(
|
||||
name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||
)
|
||||
missing = [t for t in norm if t not in existing_tag_names]
|
||||
if missing and require_existing_tags:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
existing_links = set(
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||
)
|
||||
).all()
|
||||
)
|
||||
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||
if to_add:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=out["asset_info_id"],
|
||||
tag_name=t,
|
||||
origin=tag_origin,
|
||||
added_at=now,
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
# metadata["filename"] hack
|
||||
if out["asset_info_id"] is not None:
|
||||
primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
computed_filename = compute_relative_filename(primary_path) if primary_path else None
|
||||
|
||||
current_meta = existing_info.user_metadata or {}
|
||||
new_meta = dict(current_meta)
|
||||
if user_metadata is not None:
|
||||
for k, v in user_metadata.items():
|
||||
new_meta[k] = v
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta != current_meta:
|
||||
await replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=out["asset_info_id"],
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
try:
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||
return out
|
||||
|
||||
|
||||
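A hedged example of calling ingest_fs_asset for a single file after hashing it; hashing_mod.blake3_hash is the helper imported at the top of this file, while the tag list and caller name are illustrative.

import os

async def ingest_one_file(session, abs_path: str) -> dict:
    st = os.stat(abs_path)
    digest = await hashing_mod.blake3_hash(abs_path)
    return await ingest_fs_asset(
        session,
        asset_hash=f"blake3:{digest}",
        abs_path=abs_path,
        size_bytes=st.st_size,
        mtime_ns=getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)),
        mime_type=None,
        info_name=os.path.basename(abs_path),
        tags=["models", "loras"],  # illustrative
        tag_origin="automatic",
    )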
async def touch_asset_infos_by_fs_path(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
file_path: str,
|
||||
ts: Optional[datetime] = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> None:
|
||||
locator = os.path.abspath(file_path)
|
||||
ts = ts or utcnow()
|
||||
stmt = sa.update(AssetInfo).where(
|
||||
sa.exists(
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(AssetCacheState)
|
||||
.where(
|
||||
AssetCacheState.asset_id == AssetInfo.asset_id,
|
||||
AssetCacheState.file_path == locator,
|
||||
)
|
||||
)
|
||||
)
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(
|
||||
AssetInfo.last_access_time.is_(None),
|
||||
AssetInfo.last_access_time < ts,
|
||||
)
|
||||
)
|
||||
await session.execute(stmt.values(last_access_time=ts))
|
||||
|
||||
|
||||
async def list_cache_states_with_asset_under_prefixes(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
prefixes: Sequence[str],
|
||||
) -> list[tuple[AssetCacheState, Optional[str], int]]:
|
||||
"""Return (AssetCacheState, asset_hash, size_bytes) for rows under any prefix."""
|
||||
if not prefixes:
|
||||
return []
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
if not p:
|
||||
continue
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base = base + os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
if not conds:
|
||||
return []
|
||||
|
||||
rows = (
|
||||
await session.execute(
|
||||
select(AssetCacheState, Asset.hash, Asset.size_bytes)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(sa.or_(*conds))
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).all()
|
||||
return [(r[0], r[1], int(r[2] or 0)) for r in rows]
|
||||
|
||||
|
||||
async def _recompute_and_apply_filename_for_asset(session: AsyncSession, *, asset_id: str) -> None:
|
||||
"""Compute filename from the first *existing* cache state path and apply it to all AssetInfo (if changed)."""
|
||||
try:
|
||||
primary_path = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset_id))
|
||||
if not primary_path:
|
||||
return
|
||||
new_filename = compute_relative_filename(primary_path)
|
||||
if not new_filename:
|
||||
return
|
||||
infos = (
|
||||
await session.execute(select(AssetInfo).where(AssetInfo.asset_id == asset_id))
|
||||
).scalars().all()
|
||||
for info in infos:
|
||||
current_meta = info.user_metadata or {}
|
||||
if current_meta.get("filename") == new_filename:
|
||||
continue
|
||||
updated = dict(current_meta)
|
||||
updated["filename"] = new_filename
|
||||
await replace_asset_info_metadata_projection(session, asset_info_id=info.id, user_metadata=updated)
|
||||
except Exception:
|
||||
logging.exception("Failed to recompute filename metadata for asset %s", asset_id)
|
||||
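Taken together, a background scanner could drive the functions in this file roughly as sketched below; the prefix list and batching policy are assumptions, not part of this PR.

async def hash_pending_under(session, prefixes: list[str]) -> int:
    # First hash brand-new seed rows, then re-check rows flagged needs_verify.
    pending = await list_unhashed_candidates_under_prefixes(session, prefixes=prefixes)
    pending += list(await list_verify_candidates_under_prefixes(session, prefixes=prefixes))
    done = 0
    for state_id in pending:
        # Each call hashes the file, dedups by content hash, and settles the state row.
        if await compute_hash_and_dedup_for_cache_state(session, state_id=state_id) is not None:
            done += 1
    return done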
app/database/services/info.py (new file, 586 lines)
@@ -0,0 +1,586 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import contains_eager, noload
|
||||
|
||||
from ..._assets_helpers import compute_relative_filename, normalize_tags
|
||||
from ..helpers import (
|
||||
apply_metadata_filter,
|
||||
apply_tag_filters,
|
||||
ensure_tags_exist,
|
||||
escape_like_prefix,
|
||||
project_kv,
|
||||
visible_owner_clause,
|
||||
)
|
||||
from ..models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from ..timeutil import utcnow
|
||||
from .queries import (
|
||||
get_asset_by_hash,
|
||||
list_cache_states_by_asset_id,
|
||||
pick_best_live_path,
|
||||
)
|
||||
|
||||
|
||||
async def list_asset_infos_page(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
owner_id: str = "",
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
name_contains: Optional[str] = None,
|
||||
metadata_filter: Optional[dict] = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
|
||||
base = (
|
||||
select(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
sort_map = {
|
||||
"name": AssetInfo.name,
|
||||
"created_at": AssetInfo.created_at,
|
||||
"updated_at": AssetInfo.updated_at,
|
||||
"last_access_time": AssetInfo.last_access_time,
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetInfo.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
count_stmt = (
|
||||
select(func.count())
|
||||
.select_from(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = int((await session.execute(count_stmt)).scalar_one() or 0)
|
||||
|
||||
infos = (await session.execute(base)).unique().scalars().all()
|
||||
|
||||
id_list: list[str] = [i.id for i in infos]
|
||||
tag_map: dict[str, list[str]] = defaultdict(list)
|
||||
if id_list:
|
||||
rows = await session.execute(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
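A hypothetical call site for the paging helper above; the engine URL and the commented-out run are assumptions, and the ORM tables must already exist (the application creates them via its migrations):

import asyncio

from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from app.database.services.info import list_asset_infos_page  # path as introduced in this PR


async def demo_list_page() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    make_session = async_sessionmaker(engine, expire_on_commit=False)
    async with make_session() as session:
        infos, tag_map, total = await list_asset_infos_page(
            session,
            include_tags=["models", "checkpoints"],
            name_contains="sdxl",
            limit=20,
            sort="size",
            order="desc",
        )
        for info in infos:
            print(info.name, tag_map.get(info.id, []))
        print("total:", total)


# asyncio.run(demo_list_page())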
|
||||
|
||||
|
||||
async def fetch_asset_info_and_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> Optional[tuple[AssetInfo, Asset]]:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
.options(noload(AssetInfo.tags))
|
||||
)
|
||||
row = await session.execute(stmt)
|
||||
pair = row.first()
|
||||
if not pair:
|
||||
return None
|
||||
return pair[0], pair[1]
|
||||
|
||||
|
||||
async def fetch_asset_info_asset_and_tags(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> Optional[tuple[AssetInfo, Asset, list[str]]]:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset, Tag.name)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.order_by(Tag.name.asc())
|
||||
)
|
||||
|
||||
rows = (await session.execute(stmt)).all()
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
first_info, first_asset, _ = rows[0]
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for _info, _asset, tag_name in rows:
|
||||
if tag_name and tag_name not in seen:
|
||||
seen.add(tag_name)
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
|
||||
async def create_asset_info_for_existing_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
user_metadata: Optional[dict] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AssetInfo:
|
||||
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
|
||||
now = utcnow()
|
||||
asset = await get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if not asset:
|
||||
raise ValueError(f"Unknown asset hash {asset_hash}")
|
||||
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
preview_id=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
try:
|
||||
async with session.begin_nested():
|
||||
session.add(info)
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
existing = (
|
||||
await session.execute(
|
||||
select(AssetInfo)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == name,
|
||||
AssetInfo.owner_id == owner_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalars().first()
|
||||
if not existing:
|
||||
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
|
||||
return existing
|
||||
|
||||
# metadata["filename"] hack
|
||||
new_meta = dict(user_metadata or {})
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
if new_meta:
|
||||
await replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
await set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
return info
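A hypothetical caller of the upsert above; the hash is a placeholder for illustration and must already exist in the assets table:

async def register_second_name(session) -> str:
    # Attach a new user-visible record (and tags) to an already-ingested asset.
    info = await create_asset_info_for_existing_asset(
        session,
        asset_hash="blake3:0123456789abcdef",  # placeholder hash, illustration only
        name="sdxl_base_1.0.safetensors",
        user_metadata={"purpose": "example"},
        tags=["models", "checkpoints"],
        owner_id="",
    )
    await session.commit()
    return info.id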
|
||||
|
||||
|
||||
async def set_asset_info_tags(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> dict:
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(
|
||||
tag_name for (tag_name,) in (
|
||||
await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||
).all()
|
||||
)
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
await ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
|
||||
for t in to_add
|
||||
])
|
||||
await session.flush()
|
||||
|
||||
if to_remove:
|
||||
await session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||
|
||||
|
||||
async def update_asset_info_full(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
tag_origin: str = "manual",
|
||||
asset_info_row: Any = None,
|
||||
) -> AssetInfo:
|
||||
if not asset_info_row:
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
else:
|
||||
info = asset_info_row
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
info.name = name
|
||||
touched = True
|
||||
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(await list_cache_states_by_asset_id(session, asset_id=info.asset_id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
|
||||
if user_metadata is not None:
|
||||
new_meta = dict(user_metadata)
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
await replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
else:
|
||||
if computed_filename:
|
||||
current_meta = info.user_metadata or {}
|
||||
if current_meta.get("filename") != computed_filename:
|
||||
new_meta = dict(current_meta)
|
||||
new_meta["filename"] = computed_filename
|
||||
await replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
await set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
info.updated_at = utcnow()
|
||||
await session.flush()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
async def replace_asset_info_metadata_projection(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
user_metadata: Optional[dict],
|
||||
) -> None:
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info.user_metadata = user_metadata or {}
|
||||
info.updated_at = utcnow()
|
||||
await session.flush()
|
||||
|
||||
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||
await session.flush()
|
||||
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetInfoMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in project_kv(k, v):
|
||||
rows.append(
|
||||
AssetInfoMeta(
|
||||
asset_info_id=asset_info_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
await session.flush()
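project_kv lives in the shared helpers and is not shown in this hunk; conceptually it flattens one metadata key/value into typed rows, roughly like this illustrative sketch (not the project's actual implementation):

def project_kv_sketch(key, value):
    def one(ordinal, v):
        row = {"key": key, "ordinal": ordinal,
               "val_str": None, "val_num": None, "val_bool": None, "val_json": None}
        if isinstance(v, bool):          # bool before int: bool is a subclass of int
            row["val_bool"] = v
        elif isinstance(v, (int, float)):
            row["val_num"] = float(v)
        elif isinstance(v, str):
            row["val_str"] = v
        else:                            # dicts, None, nested lists -> JSON column
            row["val_json"] = v
        return row

    if isinstance(value, (list, tuple)):
        return [one(i, v) for i, v in enumerate(value)]
    return [one(0, value)]


project_kv_sketch("flags", ["x", "y"])  # -> two rows with ordinals 0 and 1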
|
||||
|
||||
|
||||
async def touch_asset_info_by_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
ts: Optional[datetime] = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> None:
|
||||
ts = ts or utcnow()
|
||||
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
|
||||
)
|
||||
await session.execute(stmt.values(last_access_time=ts))
|
||||
|
||||
|
||||
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool:
|
||||
stmt = sa.delete(AssetInfo).where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
return int((await session.execute(stmt)).rowcount or 0) > 0
|
||||
|
||||
|
||||
async def add_tags_to_asset_info(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
create_if_missing: bool = True,
|
||||
asset_info_row: Any = None,
|
||||
) -> dict:
|
||||
if not asset_info_row:
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"added": [], "already_present": [], "total_tags": total}
|
||||
|
||||
if create_if_missing:
|
||||
await ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
current = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
async with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=asset_info_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=utcnow(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
await nested.rollback()
|
||||
|
||||
after = set(await get_asset_tags(session, asset_info_id=asset_info_id))
|
||||
return {
|
||||
"added": sorted(((after - current) & want)),
|
||||
"already_present": sorted(want & current),
|
||||
"total_tags": sorted(after),
|
||||
}
|
||||
|
||||
|
||||
async def remove_tags_from_asset_info(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
) -> dict:
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": [], "not_present": [], "total_tags": total}
|
||||
|
||||
existing = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
to_remove = sorted(set(t for t in norm if t in existing))
|
||||
not_present = sorted(set(t for t in norm if t not in existing))
|
||||
|
||||
if to_remove:
|
||||
await session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id,
|
||||
AssetInfoTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||
|
||||
|
||||
async def list_tags_with_usage(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
prefix: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc",
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetInfoTag.tag_name.label("tag_name"),
|
||||
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||
)
|
||||
.select_from(AssetInfoTag)
|
||||
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
.group_by(AssetInfoTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||
)
|
||||
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else:
|
||||
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
if not include_zero:
|
||||
total_q = total_q.where(
|
||||
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
|
||||
)
|
||||
|
||||
rows = (await session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (await session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
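A hypothetical call site, run with a live AsyncSession: the twenty most-used visible tags that start with "mod".

async def print_popular_tags(session) -> None:
    rows, total = await list_tags_with_usage(
        session, prefix="mod", limit=20, include_zero=False, order="count_desc"
    )
    for name, tag_type, count in rows:
        print(f"{name} ({tag_type}): {count}")
    print("matching tags:", total)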
|
||||
|
||||
|
||||
async def get_asset_tags(session: AsyncSession, *, asset_info_id: str) -> list[str]:
|
||||
return [
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
async def set_asset_info_preview(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: Optional[str],
|
||||
) -> None:
|
||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
if preview_asset_id is None:
|
||||
info.preview_id = None
|
||||
else:
|
||||
# validate preview asset exists
|
||||
if not await session.get(Asset, preview_asset_id):
|
||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||
info.preview_id = preview_asset_id
|
||||
|
||||
info.updated_at = utcnow()
|
||||
await session.flush()
|
||||
76
app/database/services/queries.py
Normal file
@ -0,0 +1,76 @@
|
||||
import os
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..models import Asset, AssetCacheState, AssetInfo
|
||||
|
||||
|
||||
async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
|
||||
row = (
|
||||
await session.execute(
|
||||
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
|
||||
async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Optional[Asset]:
|
||||
return (
|
||||
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
|
||||
|
||||
async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) -> Optional[AssetInfo]:
|
||||
return await session.get(AssetInfo, asset_info_id)
|
||||
|
||||
|
||||
async def asset_info_exists_for_asset_id(session: AsyncSession, *, asset_id: str) -> bool:
|
||||
q = (
|
||||
select(sa.literal(True))
|
||||
.select_from(AssetInfo)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.limit(1)
|
||||
)
|
||||
return (await session.execute(q)).first() is not None
|
||||
|
||||
|
||||
async def get_cache_state_by_asset_id(session: AsyncSession, *, asset_id: str) -> Optional[AssetCacheState]:
|
||||
return (
|
||||
await session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset_id)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
.limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
|
||||
|
||||
async def list_cache_states_by_asset_id(
|
||||
session: AsyncSession, *, asset_id: str
|
||||
) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]:
|
||||
return (
|
||||
await session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset_id)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
def pick_best_live_path(states: Union[list[AssetCacheState], Sequence[AssetCacheState]]) -> str:
|
||||
"""
|
||||
Return the best on-disk path among cache states:
|
||||
1) Prefer a path that exists with needs_verify == False (already verified).
|
||||
2) Otherwise, pick the first path that exists.
|
||||
3) Otherwise return empty string.
|
||||
"""
|
||||
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
|
||||
if not alive:
|
||||
return ""
|
||||
for s in alive:
|
||||
if not getattr(s, "needs_verify", False):
|
||||
return s.file_path
|
||||
return alive[0].file_path
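A quick worked example of the selection order documented above, using stand-in objects instead of ORM rows:

import os
import tempfile
from types import SimpleNamespace

with tempfile.NamedTemporaryFile(delete=False) as f:
    existing = f.name

states = [
    SimpleNamespace(file_path="/nonexistent/a.bin", needs_verify=False),
    SimpleNamespace(file_path=existing, needs_verify=True),
]
# The dead path is filtered out, so the live-but-unverified path wins.
print(pick_best_live_path(states) == existing)  # True
os.unlink(existing)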
|
||||
6
app/database/timeutil.py
Normal file
@ -0,0 +1,6 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None)
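Because the stored values are naive, reattach UTC explicitly when they leave the database layer, for example:

from datetime import timezone

ts = utcnow()
assert ts.tzinfo is None                             # stored and compared as naive UTC
print(ts.replace(tzinfo=timezone.utc).isoformat())   # e.g. when formatting for API responses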
|
||||
@ -1,331 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
from folder_paths import get_relative_path, get_full_path
|
||||
from app.database.db import create_session, dependencies_available, can_create_session
|
||||
import blake3
|
||||
import comfy.utils
|
||||
|
||||
|
||||
if dependencies_available():
|
||||
from app.database.models import Model
|
||||
|
||||
|
||||
class ModelProcessor:
|
||||
def _validate_path(self, model_path):
|
||||
try:
|
||||
if not self._file_exists(model_path):
|
||||
logging.error(f"Model file not found: {model_path}")
|
||||
return None
|
||||
|
||||
result = get_relative_path(model_path)
|
||||
if not result:
|
||||
logging.error(
|
||||
f"Model file not in a recognized model directory: {model_path}"
|
||||
)
|
||||
return None
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error(f"Error validating model path {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _file_exists(self, path):
|
||||
"""Check if a file exists."""
|
||||
return os.path.exists(path)
|
||||
|
||||
def _get_file_size(self, path):
|
||||
"""Get file size."""
|
||||
return os.path.getsize(path)
|
||||
|
||||
def _get_hasher(self):
|
||||
return blake3.blake3()
|
||||
|
||||
def _hash_file(self, model_path):
|
||||
try:
|
||||
hasher = self._get_hasher()
|
||||
with open(model_path, "rb", buffering=0) as f:
|
||||
b = bytearray(128 * 1024)
|
||||
mv = memoryview(b)
|
||||
while n := f.readinto(mv):
|
||||
hasher.update(mv[:n])
|
||||
return hasher.hexdigest()
|
||||
except Exception as e:
|
||||
logging.error(f"Error hashing file {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _get_existing_model(self, session, model_type, model_relative_path):
|
||||
return (
|
||||
session.query(Model)
|
||||
.filter(Model.type == model_type)
|
||||
.filter(Model.path == model_relative_path)
|
||||
.first()
|
||||
)
|
||||
|
||||
def _ensure_source_url(self, session, model, source_url):
|
||||
if model.source_url is None:
|
||||
model.source_url = source_url
|
||||
session.commit()
|
||||
|
||||
def _update_database(
|
||||
self,
|
||||
session,
|
||||
model_type,
|
||||
model_path,
|
||||
model_relative_path,
|
||||
model_hash,
|
||||
model,
|
||||
source_url,
|
||||
):
|
||||
try:
|
||||
if not model:
|
||||
model = self._get_existing_model(
|
||||
session, model_type, model_relative_path
|
||||
)
|
||||
|
||||
if not model:
|
||||
model = Model(
|
||||
path=model_relative_path,
|
||||
type=model_type,
|
||||
file_name=os.path.basename(model_path),
|
||||
)
|
||||
session.add(model)
|
||||
|
||||
model.file_size = self._get_file_size(model_path)
|
||||
model.hash = model_hash
|
||||
if model_hash:
|
||||
model.hash_algorithm = "blake3"
|
||||
model.source_url = source_url
|
||||
|
||||
session.commit()
|
||||
return model
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error updating database for {model_relative_path}: {str(e)}"
|
||||
)
|
||||
|
||||
def process_file(self, model_path, source_url=None, model_hash=None):
|
||||
"""
|
||||
Process a model file and update the database with metadata.
|
||||
If the file already exists and matches the database, it will not be processed again.
|
||||
Returns the model object, or None if an error occurs.
|
||||
"""
|
||||
try:
|
||||
if not can_create_session():
|
||||
return
|
||||
|
||||
result = self._validate_path(model_path)
|
||||
if not result:
|
||||
return
|
||||
model_type, model_relative_path = result
|
||||
|
||||
with create_session() as session:
|
||||
session.expire_on_commit = False
|
||||
|
||||
existing_model = self._get_existing_model(
|
||||
session, model_type, model_relative_path
|
||||
)
|
||||
if (
|
||||
existing_model
|
||||
and existing_model.hash
|
||||
and existing_model.file_size == self._get_file_size(model_path)
|
||||
):
|
||||
# File exists with hash and same size, no need to process
|
||||
self._ensure_source_url(session, existing_model, source_url)
|
||||
return existing_model
|
||||
|
||||
if model_hash:
|
||||
model_hash = model_hash.lower()
|
||||
logging.info(f"Using provided hash: {model_hash}")
|
||||
else:
|
||||
start_time = time.time()
|
||||
logging.info(f"Hashing model {model_relative_path}")
|
||||
model_hash = self._hash_file(model_path)
|
||||
if not model_hash:
|
||||
return
|
||||
logging.info(
|
||||
f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
|
||||
)
|
||||
|
||||
return self._update_database(
|
||||
session,
|
||||
model_type,
|
||||
model_path,
|
||||
model_relative_path,
|
||||
model_hash,
|
||||
existing_model,
|
||||
source_url,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing model file {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def retrieve_model_by_hash(self, model_hash, model_type=None, session=None):
|
||||
"""
|
||||
Retrieve a model file from the database by hash and optionally by model type.
|
||||
Returns the model object or None if the model doesn't exist or an error occurs.
|
||||
"""
|
||||
try:
|
||||
if not can_create_session():
|
||||
return
|
||||
|
||||
dispose_session = False
|
||||
|
||||
if session is None:
|
||||
session = create_session()
|
||||
dispose_session = True
|
||||
|
||||
model = session.query(Model).filter(Model.hash == model_hash)
|
||||
if model_type is not None:
|
||||
model = model.filter(Model.type == model_type)
|
||||
return model.first()
|
||||
except Exception as e:
|
||||
logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}")
|
||||
return None
|
||||
finally:
|
||||
if dispose_session:
|
||||
session.close()
|
||||
|
||||
def retrieve_hash(self, model_path, model_type=None):
|
||||
"""
|
||||
Retrieve the hash of a model file from the database.
|
||||
Returns the hash or None if the model doesn't exist or an error occurs.
|
||||
"""
|
||||
try:
|
||||
if not can_create_session():
|
||||
return
|
||||
|
||||
if model_type is not None:
|
||||
result = self._validate_path(model_path)
|
||||
if not result:
|
||||
return None
|
||||
model_type, model_relative_path = result
|
||||
|
||||
with create_session() as session:
|
||||
model = self._get_existing_model(
|
||||
session, model_type, model_relative_path
|
||||
)
|
||||
if model and model.hash:
|
||||
return model.hash
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error retrieving hash for {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _validate_file_extension(self, file_name):
|
||||
"""Validate that the file extension is supported."""
|
||||
extension = os.path.splitext(file_name)[1]
|
||||
if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"):
|
||||
raise ValueError(f"Unsupported unsafe file for download: {file_name}")
|
||||
|
||||
def _check_existing_file(self, model_type, file_name, expected_hash):
|
||||
"""Check if file exists and has correct hash."""
|
||||
destination_path = get_full_path(model_type, file_name, allow_missing=True)
|
||||
if self._file_exists(destination_path):
|
||||
model = self.process_file(destination_path)
|
||||
if model and (expected_hash is None or model.hash == expected_hash):
|
||||
logging.debug(
|
||||
f"File {destination_path} already exists in the database and has the correct hash or no hash was provided."
|
||||
)
|
||||
return destination_path
|
||||
else:
|
||||
raise ValueError(
|
||||
f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again."
|
||||
)
|
||||
return None
|
||||
|
||||
def _check_existing_file_by_hash(self, hash, type, url):
|
||||
"""Check if a file with the given hash exists in the database and on disk."""
|
||||
hash = hash.lower()
|
||||
with create_session() as session:
|
||||
model = self.retrieve_model_by_hash(hash, type, session)
|
||||
if model:
|
||||
existing_path = get_full_path(type, model.path)
|
||||
if existing_path:
|
||||
logging.debug(
|
||||
f"File {model.path} already exists in the database at {existing_path}"
|
||||
)
|
||||
self._ensure_source_url(session, model, url)
|
||||
return existing_path
|
||||
else:
|
||||
logging.debug(
|
||||
f"File {model.path} exists in the database but not on disk"
|
||||
)
|
||||
return None
|
||||
|
||||
def _download_file(self, url, destination_path, hasher):
|
||||
"""Download a file and update the hasher with its contents."""
|
||||
response = requests.get(url, stream=True)
|
||||
logging.info(f"Downloading {url} to {destination_path}")
|
||||
|
||||
with open(destination_path, "wb") as f:
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
if total_size > 0:
|
||||
pbar = comfy.utils.ProgressBar(total_size)
|
||||
else:
|
||||
pbar = None
|
||||
with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
|
||||
for chunk in response.iter_content(chunk_size=128 * 1024):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
hasher.update(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
if pbar:
|
||||
pbar.update(len(chunk))
|
||||
|
||||
def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path):
|
||||
"""Verify that the downloaded file has the expected hash."""
|
||||
if expected_hash is not None and calculated_hash != expected_hash:
|
||||
self._remove_file(destination_path)
|
||||
raise ValueError(
|
||||
f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}"
|
||||
)
|
||||
|
||||
def _remove_file(self, file_path):
|
||||
"""Remove a file from disk."""
|
||||
os.remove(file_path)
|
||||
|
||||
def ensure_downloaded(self, type, url, desired_file_name, hash=None):
|
||||
"""
|
||||
Ensure a model file is downloaded and has the correct hash.
|
||||
Returns the path to the downloaded file.
|
||||
"""
|
||||
logging.debug(
|
||||
f"Ensuring {type} file is downloaded. URL='{url}' Destination='{desired_file_name}' Hash='{hash}'"
|
||||
)
|
||||
|
||||
# Validate file extension
|
||||
self._validate_file_extension(desired_file_name)
|
||||
|
||||
# Check if file exists with correct hash
|
||||
if hash:
|
||||
existing_path = self._check_existing_file_by_hash(hash, type, url)
|
||||
if existing_path:
|
||||
return existing_path
|
||||
|
||||
# Check if file exists locally
|
||||
destination_path = get_full_path(type, desired_file_name, allow_missing=True)
|
||||
existing_path = self._check_existing_file(type, desired_file_name, hash)
|
||||
if existing_path:
|
||||
return existing_path
|
||||
|
||||
# Download the file
|
||||
hasher = self._get_hasher()
|
||||
self._download_file(url, destination_path, hasher)
|
||||
|
||||
# Verify hash
|
||||
calculated_hash = hasher.hexdigest()
|
||||
self._verify_downloaded_hash(calculated_hash, hash, destination_path)
|
||||
|
||||
# Update database
|
||||
self.process_file(destination_path, url, calculated_hash)
|
||||
|
||||
# TODO: Notify frontend to reload models
|
||||
|
||||
return destination_path
|
||||
|
||||
|
||||
model_processor = ModelProcessor()
|
||||
0
app/storage/__init__.py
Normal file
72
app/storage/hashing.py
Normal file
@ -0,0 +1,72 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import IO, Union
|
||||
|
||||
from blake3 import blake3
|
||||
|
||||
DEFAULT_CHUNK = 8 * 1024 * 1024 # 8 MiB
|
||||
|
||||
|
||||
def _hash_file_obj_sync(file_obj: IO[bytes], chunk_size: int) -> str:
|
||||
"""Hash an already-open binary file object by streaming in chunks.
|
||||
- Seeks to the beginning before reading (if supported).
|
||||
- Restores the original position afterward (if tell/seek are supported).
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
chunk_size = DEFAULT_CHUNK
|
||||
|
||||
orig_pos = None
|
||||
if hasattr(file_obj, "tell"):
|
||||
orig_pos = file_obj.tell()
|
||||
|
||||
try:
|
||||
if hasattr(file_obj, "seek"):
|
||||
file_obj.seek(0)
|
||||
|
||||
h = blake3()
|
||||
while True:
|
||||
chunk = file_obj.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
finally:
|
||||
if hasattr(file_obj, "seek") and orig_pos is not None:
|
||||
file_obj.seek(orig_pos)
|
||||
|
||||
|
||||
def blake3_hash_sync(
|
||||
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Returns a BLAKE3 hex digest for ``fp``, which may be:
|
||||
- a filename (str/bytes) or PathLike
|
||||
- an open binary file object
|
||||
|
||||
If ``fp`` is a file object, it must be opened in **binary** mode and support
|
||||
``read``, ``seek``, and ``tell``. The function will seek to the start before
|
||||
reading and will attempt to restore the original position afterward.
|
||||
"""
|
||||
if hasattr(fp, "read"):
|
||||
return _hash_file_obj_sync(fp, chunk_size)
|
||||
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj_sync(f, chunk_size)
|
||||
|
||||
|
||||
async def blake3_hash(
|
||||
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Async wrapper for ``blake3_hash_sync``.
|
||||
Uses a worker thread so the event loop remains responsive.
|
||||
"""
|
||||
# If it is a path, open inside the worker thread to keep I/O off the loop.
|
||||
if hasattr(fp, "read"):
|
||||
return await asyncio.to_thread(blake3_hash_sync, fp, chunk_size)
|
||||
|
||||
def _worker() -> str:
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj_sync(f, chunk_size)
|
||||
|
||||
return await asyncio.to_thread(_worker)
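A quick self-check of the two entry points above on a throwaway file (illustrative only):

import asyncio
import os
import tempfile


async def demo_hash() -> None:
    with tempfile.NamedTemporaryFile(delete=False) as f:
        f.write(b"hello world")
        path = f.name
    try:
        # Path input and file-object input must agree, and the async wrapper
        # must match the sync implementation.
        assert blake3_hash_sync(path) == await blake3_hash(path)
        with open(path, "rb") as fobj:
            assert blake3_hash_sync(fobj) == blake3_hash_sync(path)
    finally:
        os.unlink(path)


asyncio.run(demo_hash())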
|
||||
@ -1,110 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, TypedDict
|
||||
import os
|
||||
import logging
|
||||
import comfy.utils
|
||||
|
||||
import folder_paths
|
||||
|
||||
class AssetMetadata(TypedDict):
|
||||
device: Any
|
||||
return_metadata: bool
|
||||
subdir: str
|
||||
download_url: str
|
||||
|
||||
class AssetInfo:
|
||||
def __init__(self, hash: str=None, name: str=None, tags: list[str]=None, metadata: AssetMetadata={}):
|
||||
self.hash = hash
|
||||
self.name = name
|
||||
self.tags = tags
|
||||
self.metadata = metadata
|
||||
|
||||
|
||||
class ReturnedAssetABC(ABC):
|
||||
def __init__(self, mimetype: str):
|
||||
self.mimetype = mimetype
|
||||
|
||||
|
||||
class ModelReturnedAsset(ReturnedAssetABC):
|
||||
def __init__(self, state_dict: dict[str, str], metadata: dict[str, str]=None):
|
||||
super().__init__("model")
|
||||
self.state_dict = state_dict
|
||||
self.metadata = metadata
|
||||
|
||||
|
||||
class AssetResolverABC(ABC):
|
||||
@abstractmethod
|
||||
def resolve(self, asset_info: AssetInfo) -> ReturnedAssetABC:
|
||||
...
|
||||
|
||||
|
||||
class LocalAssetResolver(AssetResolverABC):
|
||||
def resolve(self, asset_info: AssetInfo, cache_result: bool=True) -> ReturnedAssetABC:
|
||||
# currently only supports models - make sure models is in the tags
|
||||
if "models" not in asset_info.tags:
|
||||
return None
|
||||
# TODO: if hash exists, call model processor to try to get info about model:
|
||||
if asset_info.hash:
|
||||
try:
|
||||
from app.model_processor import model_processor
|
||||
model_db = model_processor.retrieve_model_by_hash(asset_info.hash)
|
||||
full_path = model_db.path
|
||||
except Exception as e:
|
||||
logging.error(f"Could not get model by hash with error: {e}")
|
||||
# the good ol' bread and butter - folder_paths's keys as tags
|
||||
folder_keys = folder_paths.folder_names_and_paths.keys()
|
||||
parent_paths = []
|
||||
for tag in asset_info.tags:
|
||||
if tag in folder_keys:
|
||||
parent_paths.append(tag)
|
||||
# if subdir metadata and name exists, use that as the model name going forward
|
||||
if "subdir" in asset_info.metadata and asset_info.name:
|
||||
# if no matching parent paths, then something went wrong and should return None
|
||||
if len(parent_paths) == 0:
|
||||
return None
|
||||
relative_path = os.path.join(asset_info.metadata["subdir"], asset_info.name)
|
||||
# now we have the parent keys, we can try to get the local path
|
||||
chosen_parent = None
|
||||
full_path = None
|
||||
for parent_path in parent_paths:
|
||||
full_path = folder_paths.get_full_path(parent_path, relative_path)
|
||||
if full_path:
|
||||
chosen_parent = parent_path
|
||||
break
|
||||
if full_path is not None:
|
||||
logging.info(f"Resolved {asset_info.name} to {full_path} in {chosen_parent}")
|
||||
# we know the path, so load the model and return it
|
||||
state_dict, metadata = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True)
|
||||
# TODO: handle caching
|
||||
return ModelReturnedAsset(state_dict, metadata)
|
||||
# if just name exists, try to find model by name in all subdirs of parent paths
|
||||
# TODO: this behavior should be configurable by user
|
||||
if asset_info.name:
|
||||
for parent_path in parent_paths:
|
||||
filelist = folder_paths.get_filename_list(parent_path)
|
||||
for file in filelist:
|
||||
if os.path.basename(file) == asset_info.name:
|
||||
full_path = folder_paths.get_full_path(parent_path, file)
|
||||
state_dict, metadata = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True)
|
||||
# TODO: handle caching
|
||||
return ModelReturnedAsset(state_dict, metadata)
|
||||
# TODO: if download_url metadata exists, download the model and load it; this should be configurable by user
|
||||
if asset_info.metadata.get("download_url", None):
|
||||
...
|
||||
return None
|
||||
|
||||
|
||||
resolvers: list[AssetResolverABC] = []
|
||||
resolvers.append(LocalAssetResolver())
|
||||
|
||||
|
||||
def resolve(asset_info: AssetInfo) -> Any:
|
||||
global resolvers
|
||||
for resolver in resolvers:
|
||||
try:
|
||||
to_return = resolver.resolve(asset_info)
|
||||
if to_return is not None:
|
||||
return resolver.resolve(asset_info)
|
||||
except Exception as e:
|
||||
logging.error(f"Error resolving asset {asset_info.name} using {resolver.__class__.__name__}: {e}")
|
||||
return None
|
||||
@ -211,8 +211,8 @@ parser.add_argument(
|
||||
database_default_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.")
|
||||
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -29,6 +29,7 @@ import itertools
|
||||
from torch.nn.functional import interpolate
|
||||
from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
from app.assets_manager import populate_db_with_asset
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
DISABLE_MMAP = args.disable_mmap
|
||||
@ -102,12 +103,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
else:
|
||||
sd = pl_sd
|
||||
|
||||
try:
|
||||
from app.model_processor import model_processor
|
||||
model_processor.process_file(ckpt)
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing file {ckpt}: {e}")
|
||||
|
||||
populate_db_with_asset(ckpt)
|
||||
return (sd, metadata) if return_metadata else sd
|
||||
|
||||
def save_torch_file(sd, ckpt, metadata=None):
|
||||
|
||||
@ -1,56 +0,0 @@
|
||||
from comfy_api.latest import io, ComfyExtension
|
||||
import comfy.asset_management
|
||||
import comfy.sd
|
||||
import folder_paths
|
||||
import logging
|
||||
import os
|
||||
|
||||
|
||||
class AssetTestNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="AssetTestNode",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Combo.Input("ckpt_name", folder_paths.get_filename_list("checkpoints")),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
io.Clip.Output(),
|
||||
io.Vae.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, ckpt_name: str):
|
||||
hash = None
|
||||
# let's get the full path just so we can retrieve the hash from the DB, if it exists
|
||||
try:
|
||||
full_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
if full_path is None:
|
||||
raise Exception(f"Model {ckpt_name} not found")
|
||||
from app.model_processor import model_processor
|
||||
hash = model_processor.retrieve_hash(full_path)
|
||||
except Exception as e:
|
||||
logging.error(f"Could not get model by hash with error: {e}")
|
||||
subdir, name = os.path.split(ckpt_name)
|
||||
asset_info = comfy.asset_management.AssetInfo(hash=hash, name=name, tags=["models", "checkpoints"], metadata={"subdir": subdir})
|
||||
asset = comfy.asset_management.resolve(asset_info)
|
||||
# /\ the stuff above should happen in execution code instead of inside the node
|
||||
# \/ the stuff below should happen in the node - confirm is a model asset, do stuff to it (already loaded? or should be called to 'load'?)
|
||||
if asset is None:
|
||||
raise Exception(f"Model {asset_info.name} not found")
|
||||
assert isinstance(asset, comfy.asset_management.ModelReturnedAsset)
|
||||
out = comfy.sd.load_state_dict_guess_config(asset.state_dict, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), metadata=asset.metadata)
|
||||
return io.NodeOutput(out[0], out[1], out[2])
|
||||
|
||||
|
||||
class AssetTestExtension(ComfyExtension):
|
||||
@classmethod
|
||||
async def get_node_list(cls):
|
||||
return [AssetTestNode]
|
||||
|
||||
|
||||
def comfy_entrypoint():
|
||||
return AssetTestExtension()
|
||||
17
main.py
@ -278,13 +278,13 @@ def cleanup_temp():
|
||||
if os.path.exists(temp_dir):
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
def setup_database():
|
||||
try:
|
||||
from app.database.db import init_db, dependencies_available
|
||||
if dependencies_available():
|
||||
init_db()
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
||||
async def setup_database():
|
||||
from app import init_db_engine, sync_seed_assets
|
||||
|
||||
await init_db_engine()
|
||||
if not args.disable_assets_autoscan:
|
||||
await sync_seed_assets(["models"])
|
||||
|
||||
|
||||
def start_comfyui(asyncio_loop=None):
|
||||
"""
|
||||
@ -309,6 +309,8 @@ def start_comfyui(asyncio_loop=None):
|
||||
asyncio.set_event_loop(asyncio_loop)
|
||||
prompt_server = server.PromptServer(asyncio_loop)
|
||||
|
||||
asyncio_loop.run_until_complete(setup_database())
|
||||
|
||||
hook_breaker_ac10a0.save_functions()
|
||||
asyncio_loop.run_until_complete(nodes.init_extra_nodes(
|
||||
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
||||
@ -317,7 +319,6 @@ def start_comfyui(asyncio_loop=None):
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
cuda_malloc_warning()
|
||||
setup_database()
|
||||
|
||||
prompt_server.add_routes()
|
||||
hijack_progress(prompt_server)
|
||||
|
||||
1
nodes.py
@ -2321,7 +2321,6 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_edit_model.py",
|
||||
"nodes_tcfg.py",
|
||||
"nodes_context_windows.py",
|
||||
"nodes_assets_test.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@ -20,6 +20,7 @@ tqdm
|
||||
psutil
|
||||
alembic
|
||||
SQLAlchemy
|
||||
aiosqlite
|
||||
av>=14.2.0
|
||||
blake3
|
||||
|
||||
|
||||
@ -37,6 +37,7 @@ from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from app import sync_seed_assets, register_assets_system
|
||||
from protocol import BinaryEventTypes
|
||||
|
||||
async def send_socket_catch_exception(function, message):
|
||||
@ -183,6 +184,7 @@ class PromptServer():
|
||||
else args.front_end_root
|
||||
)
|
||||
logging.info(f"[Prompt Server] web root: {self.web_root}")
|
||||
register_assets_system(self.app, self.user_manager)
|
||||
routes = web.RouteTableDef()
|
||||
self.routes = routes
|
||||
self.last_node_id = None
|
||||
@ -627,6 +629,7 @@ class PromptServer():
|
||||
|
||||
@routes.get("/object_info")
|
||||
async def get_object_info(request):
|
||||
await sync_seed_assets(["models"])
|
||||
with folder_paths.cache_helper:
|
||||
out = {}
|
||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||
|
||||
307
tests-assets/conftest.py
Normal file
@ -0,0 +1,307 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Callable, Optional
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
"""
|
||||
Allow overriding the database URL used by the spawned ComfyUI process.
|
||||
Priority:
|
||||
1) --db-url command line option
|
||||
2) ASSETS_TEST_DB_URL environment variable (used by CI)
|
||||
3) default: sqlite in-memory
|
||||
"""
|
||||
parser.addoption(
|
||||
"--db-url",
|
||||
action="store",
|
||||
default=os.environ.get("ASSETS_TEST_DB_URL", "sqlite+aiosqlite:///:memory:"),
|
||||
help="Async SQLAlchemy DB URL (e.g. sqlite+aiosqlite:///:memory: or postgresql+psycopg://user:pass@host/db)",
|
||||
)
|
||||
|
||||
|
||||
def _free_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def _make_base_dirs(root: Path) -> None:
|
||||
for sub in ("models", "custom_nodes", "input", "output", "temp", "user"):
|
||||
(root / sub).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
async def _wait_http_ready(base: str, session: aiohttp.ClientSession, timeout: float = 90.0) -> None:
|
||||
start = time.time()
|
||||
last_err = None
|
||||
while time.time() - start < timeout:
|
||||
try:
|
||||
async with session.get(base + "/api/assets") as r:
|
||||
if r.status in (200, 400):
|
||||
return
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
await asyncio.sleep(0.25)
|
||||
raise RuntimeError(f"ComfyUI HTTP did not become ready: {last_err}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def comfy_tmp_base_dir() -> Path:
|
||||
env_base = os.environ.get("ASSETS_TEST_BASE_DIR")
|
||||
created_by_fixture = False
|
||||
if env_base:
|
||||
tmp = Path(env_base)
|
||||
tmp.mkdir(parents=True, exist_ok=True)
|
||||
else:
|
||||
tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-"))
|
||||
created_by_fixture = True
|
||||
_make_base_dirs(tmp)
|
||||
yield tmp
|
||||
if created_by_fixture:
|
||||
with contextlib.suppress(Exception):
|
||||
for p in sorted(tmp.rglob("*"), reverse=True):
|
||||
if p.is_file() or p.is_symlink():
|
||||
p.unlink(missing_ok=True)
|
||||
for p in sorted(tmp.glob("**/*"), reverse=True):
|
||||
with contextlib.suppress(Exception):
|
||||
p.rmdir()
|
||||
tmp.rmdir()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest):
|
||||
"""
|
||||
Boot ComfyUI subprocess with:
|
||||
- sandbox base dir
|
||||
- sqlite memory DB (default)
|
||||
- autoscan disabled
|
||||
Returns (base_url, process, port)
|
||||
"""
|
||||
port = _free_port()
|
||||
db_url = request.config.getoption("--db-url")
|
||||
|
||||
logs_dir = comfy_tmp_base_dir / "logs"
|
||||
logs_dir.mkdir(exist_ok=True)
|
||||
out_log = open(logs_dir / "stdout.log", "w", buffering=1)
|
||||
err_log = open(logs_dir / "stderr.log", "w", buffering=1)
|
||||
|
||||
comfy_root = Path(__file__).resolve().parent.parent
|
||||
if not (comfy_root / "main.py").is_file():
|
||||
raise FileNotFoundError(f"main.py not found under {comfy_root}")
|
||||
|
||||
proc = subprocess.Popen(
|
||||
args=[
|
||||
sys.executable,
|
||||
"main.py",
|
||||
f"--base-directory={str(comfy_tmp_base_dir)}",
|
||||
f"--database-url={db_url}",
|
||||
"--disable-assets-autoscan",
|
||||
"--listen",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
str(port),
|
||||
"--cpu",
|
||||
],
|
||||
stdout=out_log,
|
||||
stderr=err_log,
|
||||
cwd=str(comfy_root),
|
||||
env={**os.environ},
|
||||
)
|
||||
|
||||
for _ in range(50):
|
||||
if proc.poll() is not None:
|
||||
out_log.flush()
|
||||
err_log.flush()
|
||||
raise RuntimeError(f"ComfyUI exited early with code {proc.returncode}")
|
||||
time.sleep(0.1)
|
||||
|
||||
base_url = f"http://127.0.0.1:{port}"
|
||||
try:
|
||||
async def _probe():
|
||||
async with aiohttp.ClientSession() as s:
|
||||
await _wait_http_ready(base_url, s, timeout=90.0)
|
||||
|
||||
asyncio.run(_probe())
|
||||
yield base_url, proc, port
|
||||
except Exception as e:
|
||||
with contextlib.suppress(Exception):
|
||||
proc.terminate()
|
||||
proc.wait(timeout=10)
|
||||
with contextlib.suppress(Exception):
|
||||
out_log.flush()
|
||||
err_log.flush()
|
||||
raise RuntimeError(f"ComfyUI did not become ready: {e}")
|
||||
|
||||
if proc and proc.poll() is None:
|
||||
with contextlib.suppress(Exception):
|
||||
proc.terminate()
|
||||
proc.wait(timeout=15)
|
||||
out_log.close()
|
||||
err_log.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def http() -> AsyncIterator[aiohttp.ClientSession]:
|
||||
timeout = aiohttp.ClientTimeout(total=120)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as s:
|
||||
yield s
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_base(comfy_url_and_proc) -> str:
|
||||
base_url, _proc, _port = comfy_url_and_proc
|
||||
return base_url
|
||||
|
||||
|
||||
async def _post_multipart_asset(
|
||||
session: aiohttp.ClientSession,
|
||||
base: str,
|
||||
*,
|
||||
name: str,
|
||||
tags: list[str],
|
||||
meta: dict,
|
||||
data: bytes,
|
||||
extra_fields: Optional[dict] = None,
|
||||
) -> tuple[int, dict]:
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", data, filename=name, content_type="application/octet-stream")
|
||||
form.add_field("tags", json.dumps(tags))
|
||||
form.add_field("name", name)
|
||||
form.add_field("user_metadata", json.dumps(meta))
|
||||
if extra_fields:
|
||||
for k, v in extra_fields.items():
|
||||
form.add_field(k, v)
|
||||
async with session.post(base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
return r.status, body
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_asset_bytes() -> Callable[[str, int], bytes]:
|
||||
def _make(name: str, size: int = 8192) -> bytes:
|
||||
seed = sum(ord(c) for c in name) % 251
|
||||
return bytes((i * 31 + seed) % 256 for i in range(size))
|
||||
return _make
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def asset_factory(http: aiohttp.ClientSession, api_base: str):
|
||||
"""
|
||||
Returns create(name, tags, meta, data) -> response dict
|
||||
Tracks created ids and deletes them after the test.
|
||||
"""
|
||||
created: list[str] = []
|
||||
|
||||
async def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict:
|
||||
status, body = await _post_multipart_asset(http, api_base, name=name, tags=tags, meta=meta, data=data)
|
||||
assert status in (200, 201), body
|
||||
created.append(body["id"])
|
||||
return body
|
||||
|
||||
yield create
|
||||
|
||||
# cleanup by id
|
||||
for aid in created:
|
||||
with contextlib.suppress(Exception):
|
||||
async with http.delete(f"{api_base}/api/assets/{aid}") as r:
|
||||
await r.read()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def seeded_asset(request: pytest.FixtureRequest, http: aiohttp.ClientSession, api_base: str) -> dict:
|
||||
"""
|
||||
Upload one asset with ".safetensors" extension into models/checkpoints/unit-tests/<name>.
|
||||
Returns response dict with id, asset_hash, tags, etc.
|
||||
"""
|
||||
name = "unit_1_example.safetensors"
|
||||
p = getattr(request, "param", {}) or {}
|
||||
tags: Optional[list[str]] = p.get("tags")
|
||||
if tags is None:
|
||||
tags = ["models", "checkpoints", "unit-tests", "alpha"]
|
||||
meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None}
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", b"A" * 4096, filename=name, content_type="application/octet-stream")
|
||||
form.add_field("tags", json.dumps(tags))
|
||||
form.add_field("name", name)
|
||||
form.add_field("user_metadata", json.dumps(meta))
|
||||
async with http.post(api_base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 201, body
|
||||
return body
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def autoclean_unit_test_assets(http: aiohttp.ClientSession, api_base: str):
|
||||
"""Ensure isolation by removing all AssetInfo rows tagged with 'unit-tests' after each test."""
|
||||
yield
|
||||
|
||||
while True:
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests", "limit": "500", "sort": "name"},
|
||||
) as r:
|
||||
body = await r.json()
|
||||
if r.status != 200:
|
||||
break
|
||||
ids = [a["id"] for a in body.get("assets", [])]
|
||||
if not ids:
|
||||
break
|
||||
for aid in ids:
|
||||
with contextlib.suppress(Exception):
|
||||
async with http.delete(f"{api_base}/api/assets/{aid}") as dr:
|
||||
await dr.read()
|
||||
|
||||
|
||||
async def trigger_sync_seed_assets(session: aiohttp.ClientSession, base_url: str) -> None:
|
||||
"""Force a fast sync/seed pass by calling the ComfyUI '/object_info' endpoint."""
|
||||
async with session.post(base_url + "/api/assets/scan/seed", json={"roots": ["models", "input", "output"]}) as r:
|
||||
await r.read()
|
||||
await asyncio.sleep(0.1) # tiny yield to the event loop to let any final DB commits flush
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def run_scan_and_wait(http: aiohttp.ClientSession, api_base: str):
|
||||
"""Schedule an asset scan for a given root and wait until it finishes."""
|
||||
async def _run(root: str, timeout: float = 120.0):
|
||||
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r:
|
||||
# we ignore body; scheduling returns 202 with a status payload
|
||||
await r.read()
|
||||
|
||||
start = time.time()
|
||||
while True:
|
||||
async with http.get(api_base + "/api/assets/scan", params={"root": root}) as st:
|
||||
body = await st.json()
|
||||
scans = (body or {}).get("scans", [])
|
||||
status = None
|
||||
if scans:
|
||||
status = scans[-1].get("status")
|
||||
if status in {"completed", "failed", "cancelled"}:
|
||||
if status != "completed":
|
||||
raise RuntimeError(f"Scan for root={root} finished with status={status}")
|
||||
return
|
||||
if time.time() - start > timeout:
|
||||
raise TimeoutError(f"Timed out waiting for scan of root={root}")
|
||||
await asyncio.sleep(0.1)
|
||||
return _run
|
||||
|
||||
|
||||
def get_asset_filename(asset_hash: str, extension: str) -> str:
|
||||
return asset_hash.removeprefix("blake3:") + extension
|
||||
347
tests-assets/test_assets_missing_sync.py
Normal file
@ -0,0 +1,347 @@
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from conftest import get_asset_filename, trigger_sync_seed_assets
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_seed_asset_removed_when_file_is_deleted(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
):
|
||||
"""Asset without hash (seed) whose file disappears:
|
||||
after triggering sync_seed_assets, Asset + AssetInfo disappear.
|
||||
"""
|
||||
# Create a file directly under <root>/unit-tests/<case> so the derived tags include "unit-tests"
|
||||
case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed"
|
||||
case_dir.mkdir(parents=True, exist_ok=True)
|
||||
name = f"seed_{uuid.uuid4().hex[:8]}.bin"
|
||||
fp = case_dir / name
|
||||
fp.write_bytes(b"Z" * 2048)
|
||||
|
||||
# Trigger a seed sync so DB sees this path (seed asset => hash is NULL)
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Verify it is visible via API and carries no hash (seed)
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
|
||||
) as r1:
|
||||
body1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
# there should be exactly one with that name
|
||||
matches = [a for a in body1.get("assets", []) if a.get("name") == name]
|
||||
assert matches
|
||||
assert matches[0].get("asset_hash") is None
|
||||
asset_info_id = matches[0]["id"]
|
||||
|
||||
# Remove the underlying file and sync again
|
||||
if fp.exists():
|
||||
fp.unlink()
|
||||
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# It should disappear (AssetInfo and seed Asset gone)
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
|
||||
) as r2:
|
||||
body2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
matches2 = [a for a in body2.get("assets", []) if a.get("name") == name]
|
||||
assert not matches2, f"Seed asset {asset_info_id} should be gone after sync"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hashed_asset_missing_tag_added_then_removed_after_scan(
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""Hashed asset with a single cache_state:
|
||||
1. delete its file -> sync adds 'missing'
|
||||
2. restore file -> scan removes 'missing'
|
||||
"""
|
||||
name = "missing_tag_test.png"
|
||||
tags = ["input", "unit-tests", "msync2"]
|
||||
data = make_asset_bytes(name, 4096)
|
||||
a = await asset_factory(name, tags, {}, data)
|
||||
|
||||
# Compute its on-disk path and remove it
|
||||
dest = comfy_tmp_base_dir / "input" / "unit-tests" / "msync2" / get_asset_filename(a["asset_hash"], ".png")
|
||||
assert dest.exists(), f"Expected asset file at {dest}"
|
||||
dest.unlink()
|
||||
|
||||
# Fast sync should add 'missing' to the AssetInfo
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{a['id']}") as g1:
|
||||
d1 = await g1.json()
|
||||
assert g1.status == 200, d1
|
||||
assert "missing" in set(d1.get("tags", [])), "Expected 'missing' tag after deletion"
|
||||
|
||||
# Restore the file with the exact same content and re-hash/verify via scan
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
dest.write_bytes(data)
|
||||
|
||||
await run_scan_and_wait("input")
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{a['id']}") as g2:
|
||||
d2 = await g2.json()
|
||||
assert g2.status == 200, d2
|
||||
assert "missing" not in set(d2.get("tags", [])), "Missing tag should be cleared after verify"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hashed_asset_two_asset_infos_both_get_missing(
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
):
|
||||
"""Hashed asset with a single cache_state, but two AssetInfo rows:
|
||||
deleting the single file then syncing should add 'missing' to both infos.
|
||||
"""
|
||||
# Upload one hashed asset
|
||||
name = "two_infos_one_path.png"
|
||||
base_tags = ["input", "unit-tests", "multiinfo"]
|
||||
created = await asset_factory(name, base_tags, {}, b"A" * 2048)
|
||||
|
||||
# Create second AssetInfo for the same Asset via from-hash
|
||||
payload = {
|
||||
"hash": created["asset_hash"],
|
||||
"name": "two_infos_one_path_copy.png",
|
||||
"tags": base_tags, # keep it in our unit-tests scope for cleanup
|
||||
"user_metadata": {"k": "v"},
|
||||
}
|
||||
async with http.post(api_base + "/api/assets/from-hash", json=payload) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 201, b2
|
||||
second_id = b2["id"]
|
||||
|
||||
# Remove the single underlying file
|
||||
p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / get_asset_filename(b2["asset_hash"], ".png")
|
||||
assert p.exists()
|
||||
p.unlink()
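# Baseline 'missing' tag usage before the sync, so the exact increase can be asserted afterwards.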
|
||||
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r0:
|
||||
tags0 = await r0.json()
|
||||
assert r0.status == 200, tags0
|
||||
byname0 = {t["name"]: t for t in tags0.get("tags", [])}
|
||||
old_missing = int(byname0.get("missing", {}).get("count", 0))
|
||||
|
||||
# Sync -> both AssetInfos for this asset must receive 'missing'
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{created['id']}") as ga:
|
||||
da = await ga.json()
|
||||
assert ga.status == 200, da
|
||||
assert "missing" in set(da.get("tags", []))
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{second_id}") as gb:
|
||||
db = await gb.json()
|
||||
assert gb.status == 200, db
|
||||
assert "missing" in set(db.get("tags", []))
|
||||
|
||||
# Tag usage for 'missing' increased by exactly 2 (two AssetInfos)
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r1:
|
||||
tags1 = await r1.json()
|
||||
assert r1.status == 200, tags1
|
||||
byname1 = {t["name"]: t for t in tags1.get("tags", [])}
|
||||
new_missing = int(byname1.get("missing", {}).get("count", 0))
|
||||
assert new_missing == old_missing + 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hashed_asset_two_cache_states_partial_delete_then_full_delete(
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""Hashed asset with two cache_state rows:
|
||||
1. delete one file -> sync should NOT add 'missing'
|
||||
2. delete second file -> sync should add 'missing'
|
||||
"""
|
||||
name = "two_cache_states_partial_delete.png"
|
||||
tags = ["input", "unit-tests", "dual"]
|
||||
data = make_asset_bytes(name, 3072)
|
||||
|
||||
created = await asset_factory(name, tags, {}, data)
|
||||
path1 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual" / get_asset_filename(created["asset_hash"], ".png")
|
||||
assert path1.exists()
|
||||
|
||||
# Create a second on-disk copy under the same root but different subfolder
|
||||
path2 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual_copy" / name
|
||||
path2.parent.mkdir(parents=True, exist_ok=True)
|
||||
path2.write_bytes(data)
|
||||
|
||||
# Fast seed so the second path appears (as a seed initially)
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Deduplication of the AssetInfo rows will not happen here: the first AssetInfo has owner='default' while the second has an empty owner.
|
||||
await run_scan_and_wait("input")
|
||||
|
||||
# Remove only one file and sync -> asset should still be healthy (no 'missing')
|
||||
path1.unlink()
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{created['id']}") as g1:
|
||||
d1 = await g1.json()
|
||||
assert g1.status == 200, d1
|
||||
assert "missing" not in set(d1.get("tags", [])), "Should not be missing while one valid path remains"
|
||||
|
||||
# Baseline 'missing' usage count just before last file removal
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r0:
|
||||
tags0 = await r0.json()
|
||||
assert r0.status == 200, tags0
|
||||
old_missing = int({t["name"]: t for t in tags0.get("tags", [])}.get("missing", {}).get("count", 0))
|
||||
|
||||
# Remove the second (last) file and sync -> now we expect 'missing' on this AssetInfo
|
||||
path2.unlink()
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{created['id']}") as g2:
|
||||
d2 = await g2.json()
|
||||
assert g2.status == 200, d2
|
||||
assert "missing" in set(d2.get("tags", [])), "Missing must be set once no valid paths remain"
|
||||
|
||||
# Tag usage for 'missing' increased by exactly 2 (two AssetInfos for one Asset)
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r1:
|
||||
tags1 = await r1.json()
|
||||
assert r1.status == 200, tags1
|
||||
new_missing = int({t["name"]: t for t in tags1.get("tags", [])}.get("missing", {}).get("count", 0))
|
||||
assert new_missing == old_missing + 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""
|
||||
Fast pass alone clears 'missing' when size and mtime match exactly:
|
||||
1) upload (hashed), record original mtime_ns
|
||||
2) delete -> fast pass adds 'missing'
|
||||
3) restore same bytes and set mtime back to the original value
|
||||
4) run fast pass again -> 'missing' is removed (no slow scan)
|
||||
"""
|
||||
scope = f"fastclear-{uuid.uuid4().hex[:6]}"
|
||||
name = "fastpass_clear.bin"
|
||||
data = make_asset_bytes(name, 3072)
|
||||
|
||||
a = await asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||
aid = a["id"]
|
||||
base = comfy_tmp_base_dir / root / "unit-tests" / scope
|
||||
p = base / get_asset_filename(a["asset_hash"], ".bin")
|
||||
st0 = p.stat()
|
||||
orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000))
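# Record the original mtime in nanoseconds: per the docstring above, the fast pass is expected to clear 'missing' only when both size and mtime match exactly.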
|
||||
|
||||
# Delete -> fast pass adds 'missing'
|
||||
p.unlink()
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as g1:
|
||||
d1 = await g1.json()
|
||||
assert g1.status == 200, d1
|
||||
assert "missing" in set(d1.get("tags", []))
|
||||
|
||||
# Restore same bytes and revert mtime to the original value
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
p.write_bytes(data)
|
||||
# set both atime and mtime in ns to ensure exact match
|
||||
os.utime(p, ns=(orig_mtime_ns, orig_mtime_ns))
|
||||
|
||||
# Fast pass should clear 'missing' without a scan
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as g2:
|
||||
d2 = await g2.json()
|
||||
assert g2.status == 200, d2
|
||||
assert "missing" not in set(d2.get("tags", [])), "Fast pass should clear 'missing' when size+mtime match"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_fastpass_removes_stale_state_row_no_missing(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""
|
||||
Hashed asset with two states:
|
||||
- delete one file
|
||||
- run fast pass only
|
||||
Expect:
|
||||
- asset stays healthy (no 'missing')
|
||||
- stale AssetCacheState row for the deleted path is removed.
|
||||
We verify this behaviorally by recreating the deleted path and running fast pass again:
|
||||
a new *seed* AssetInfo is created, which proves the old state row was not reused.
|
||||
"""
|
||||
scope = f"stale-{uuid.uuid4().hex[:6]}"
|
||||
name = "two_states.bin"
|
||||
data = make_asset_bytes(name, 2048)
|
||||
|
||||
# Upload hashed asset at path1
|
||||
a = await asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||
base = comfy_tmp_base_dir / root / "unit-tests" / scope
|
||||
a1_filename = get_asset_filename(a["asset_hash"], ".bin")
|
||||
p1 = base / a1_filename
|
||||
assert p1.exists()
|
||||
|
||||
aid = a["id"]
|
||||
h = a["asset_hash"]
|
||||
|
||||
# Create second state path2, seed+scan to dedupe into the same Asset
|
||||
p2 = base / "copy" / name
|
||||
p2.parent.mkdir(parents=True, exist_ok=True)
|
||||
p2.write_bytes(data)
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
await run_scan_and_wait(root)
|
||||
|
||||
# Delete path1 and run fast pass -> no 'missing' and stale state row should be removed
|
||||
p1.unlink()
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as g1:
|
||||
d1 = await g1.json()
|
||||
assert g1.status == 200, d1
|
||||
assert "missing" not in set(d1.get("tags", []))
|
||||
|
||||
# Recreate path1 and run fast pass again.
|
||||
# If the stale state row was removed, a NEW seed AssetInfo will appear for this path.
|
||||
p1.write_bytes(data)
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}"},
|
||||
) as rl:
|
||||
bl = await rl.json()
|
||||
assert rl.status == 200, bl
|
||||
items = bl.get("assets", [])
|
||||
# one hashed AssetInfo (asset_hash == h) + one seed AssetInfo (asset_hash == null)
|
||||
hashes = [it.get("asset_hash") for it in items if it.get("name") in (name, a1_filename)]
|
||||
assert h in hashes
|
||||
assert any(x is None for x in hashes), "Expected a new seed AssetInfo for the recreated path"
|
||||
|
||||
# Asset identity still healthy
|
||||
async with http.head(f"{api_base}/api/assets/hash/{h}") as rh:
|
||||
assert rh.status == 200
316
tests-assets/test_crud.py
Normal file
@ -0,0 +1,316 @@
import asyncio
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from conftest import get_asset_filename, trigger_sync_seed_assets
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_from_hash_success(
|
||||
http: aiohttp.ClientSession, api_base: str, seeded_asset: dict
|
||||
):
|
||||
h = seeded_asset["asset_hash"]
|
||||
payload = {
|
||||
"hash": h,
|
||||
"name": "from_hash_ok.safetensors",
|
||||
"tags": ["models", "checkpoints", "unit-tests", "from-hash"],
|
||||
"user_metadata": {"k": "v"},
|
||||
}
|
||||
async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 201, b1
|
||||
assert b1["asset_hash"] == h
|
||||
assert b1["created_new"] is False
|
||||
aid = b1["id"]
|
||||
|
||||
# Calling again with the same name should return the same AssetInfo id
|
||||
async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 201, b2
|
||||
assert b2["id"] == aid
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_and_delete_asset(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
# GET detail
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
|
||||
detail = await rg.json()
|
||||
assert rg.status == 200, detail
|
||||
assert detail["id"] == aid
|
||||
assert "user_metadata" in detail
|
||||
assert "filename" in detail["user_metadata"]
|
||||
|
||||
# DELETE
|
||||
async with http.delete(f"{api_base}/api/assets/{aid}") as rd:
|
||||
assert rd.status == 204
|
||||
|
||||
# GET again -> 404
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg2:
|
||||
body = await rg2.json()
|
||||
assert rg2.status == 404
|
||||
assert body["error"]["code"] == "ASSET_NOT_FOUND"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_upon_reference_count(
|
||||
http: aiohttp.ClientSession, api_base: str, seeded_asset: dict
|
||||
):
|
||||
# Create a second reference to the same asset via from-hash
|
||||
src_hash = seeded_asset["asset_hash"]
|
||||
payload = {
|
||||
"hash": src_hash,
|
||||
"name": "unit_ref_copy.safetensors",
|
||||
"tags": ["models", "checkpoints", "unit-tests", "del-flow"],
|
||||
"user_metadata": {"note": "copy"},
|
||||
}
|
||||
async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r2:
|
||||
copy = await r2.json()
|
||||
assert r2.status == 201, copy
|
||||
assert copy["asset_hash"] == src_hash
|
||||
assert copy["created_new"] is False
|
||||
|
||||
# Delete original reference -> asset identity must remain
|
||||
aid1 = seeded_asset["id"]
|
||||
async with http.delete(f"{api_base}/api/assets/{aid1}") as rd1:
|
||||
assert rd1.status == 204
|
||||
|
||||
async with http.head(f"{api_base}/api/assets/hash/{src_hash}") as rh1:
|
||||
assert rh1.status == 200 # identity still present
|
||||
|
||||
# Delete the last reference with default semantics -> identity and cached files removed
|
||||
aid2 = copy["id"]
|
||||
async with http.delete(f"{api_base}/api/assets/{aid2}") as rd2:
|
||||
assert rd2.status == 204
|
||||
|
||||
async with http.head(f"{api_base}/api/assets/hash/{src_hash}") as rh2:
|
||||
assert rh2.status == 404 # orphan content removed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_asset_fields(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
payload = {
|
||||
"name": "unit_1_renamed.safetensors",
|
||||
"tags": ["models", "checkpoints", "unit-tests", "beta"],
|
||||
"user_metadata": {"purpose": "updated", "epoch": 2},
|
||||
}
|
||||
async with http.put(f"{api_base}/api/assets/{aid}", json=payload) as ru:
|
||||
body = await ru.json()
|
||||
assert ru.status == 200, body
|
||||
assert body["name"] == payload["name"]
|
||||
assert "beta" in body["tags"]
|
||||
assert body["user_metadata"]["purpose"] == "updated"
|
||||
# filename should still be present and normalized by server
|
||||
assert "filename" in body["user_metadata"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_head_asset_by_hash(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
h = seeded_asset["asset_hash"]
|
||||
|
||||
# Existing
|
||||
async with http.head(f"{api_base}/api/assets/hash/{h}") as rh1:
|
||||
assert rh1.status == 200
|
||||
|
||||
# Non-existent
|
||||
async with http.head(f"{api_base}/api/assets/hash/blake3:{'0'*64}") as rh2:
|
||||
assert rh2.status == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_head_asset_bad_hash_returns_400_and_no_body(http: aiohttp.ClientSession, api_base: str):
|
||||
# Invalid format; handler returns a JSON error, but HEAD responses must not carry a payload.
|
||||
# aiohttp exposes an empty body for HEAD, so validate status and that there is no payload.
|
||||
async with http.head(f"{api_base}/api/assets/hash/not_a_hash") as rh:
|
||||
assert rh.status == 400
|
||||
body = await rh.read()
|
||||
assert body == b""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent_returns_404(http: aiohttp.ClientSession, api_base: str):
|
||||
bogus = str(uuid.uuid4())
|
||||
async with http.delete(f"{api_base}/api/assets/{bogus}") as r:
|
||||
body = await r.json()
|
||||
assert r.status == 404
|
||||
assert body["error"]["code"] == "ASSET_NOT_FOUND"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_from_hash_invalids(http: aiohttp.ClientSession, api_base: str):
|
||||
# Bad hash algorithm
|
||||
bad = {
|
||||
"hash": "sha256:" + "0" * 64,
|
||||
"name": "x.bin",
|
||||
"tags": ["models", "checkpoints", "unit-tests"],
|
||||
}
|
||||
async with http.post(f"{api_base}/api/assets/from-hash", json=bad) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 400
|
||||
assert b1["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
# Invalid JSON body
|
||||
async with http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}") as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 400
|
||||
assert b2["error"]["code"] == "INVALID_JSON"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_update_download_bad_ids(http: aiohttp.ClientSession, api_base: str):
|
||||
# All endpoints should return 404, since the route definition matches the id with a UUID regex.
|
||||
bad_id = "not-a-uuid"
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{bad_id}") as r1:
|
||||
assert r1.status == 404
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{bad_id}/content") as r3:
|
||||
assert r3.status == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_requires_at_least_one_field(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
async with http.put(f"{api_base}/api/assets/{aid}", json={}) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_concurrent_delete_same_asset_info_single_204(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""
|
||||
Many concurrent DELETE for the same AssetInfo should result in:
|
||||
- exactly one 204 No Content (the one that actually deleted)
|
||||
- all others 404 Not Found (row already gone)
|
||||
"""
|
||||
scope = f"conc-del-{uuid.uuid4().hex[:6]}"
|
||||
name = "to_delete.bin"
|
||||
data = make_asset_bytes(name, 1536)
|
||||
|
||||
created = await asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||
aid = created["id"]
|
||||
|
||||
# Hit the same endpoint N times in parallel.
|
||||
n_tests = 4
|
||||
url = f"{api_base}/api/assets/{aid}?delete_content=false"
|
||||
tasks = [asyncio.create_task(http.delete(url)) for _ in range(n_tests)]
|
||||
responses = await asyncio.gather(*tasks)
|
||||
|
||||
statuses = [r.status for r in responses]
|
||||
# Drain bodies to close connections (optional but avoids warnings).
|
||||
await asyncio.gather(*[r.read() for r in responses])
|
||||
|
||||
# Exactly one actual delete, the rest must be 404
|
||||
assert statuses.count(204) == 1, f"Expected exactly one 204; got: {statuses}"
|
||||
assert statuses.count(404) == n_tests - 1, f"Expected {n_tests-1} 404; got: {statuses}"
|
||||
|
||||
# The resource must be gone.
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
|
||||
assert rg.status == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_metadata_filename_is_set_for_seed_asset_without_hash(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
):
|
||||
"""Seed ingest (no hash yet) must compute user_metadata['filename'] immediately."""
|
||||
scope = f"seedmeta-{uuid.uuid4().hex[:6]}"
|
||||
name = "seed_filename.bin"
|
||||
|
||||
base = comfy_tmp_base_dir / root / "unit-tests" / scope / "a" / "b"
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
fp = base / name
|
||||
fp.write_bytes(b"Z" * 2048)
|
||||
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
|
||||
) as r1:
|
||||
body = await r1.json()
|
||||
assert r1.status == 200, body
|
||||
matches = [a for a in body.get("assets", []) if a.get("name") == name]
|
||||
assert matches, "Seed asset should be visible after sync"
|
||||
assert matches[0].get("asset_hash") is None # still a seed
|
||||
aid = matches[0]["id"]
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as r2:
|
||||
detail = await r2.json()
|
||||
assert r2.status == 200, detail
|
||||
filename = (detail.get("user_metadata") or {}).get("filename")
|
||||
expected = str(fp.relative_to(comfy_tmp_base_dir / root)).replace("\\", "/")
|
||||
assert filename == expected, f"expected filename={expected}, got {filename!r}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_metadata_filename_computed_and_updated_on_retarget(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""
|
||||
1) Ingest under {root}/unit-tests/<scope>/a/b/<name> -> filename reflects relative path.
|
||||
2) Retarget by copying to {root}/unit-tests/<scope>/x/<new_name>, remove old file,
|
||||
run fast pass + scan -> filename updates to new relative path.
|
||||
"""
|
||||
scope = f"meta-fn-{uuid.uuid4().hex[:6]}"
|
||||
name1 = "compute_metadata_filename.png"
|
||||
name2 = "compute_changed_metadata_filename.png"
|
||||
data = make_asset_bytes(name1, 2100)
|
||||
|
||||
# Upload into nested path a/b
|
||||
a = await asset_factory(name1, [root, "unit-tests", scope, "a", "b"], {}, data)
|
||||
aid = a["id"]
|
||||
|
||||
root_base = comfy_tmp_base_dir / root
|
||||
p1 = (root_base / "unit-tests" / scope / "a" / "b" / get_asset_filename(a["asset_hash"], ".png"))
|
||||
assert p1.exists()
|
||||
|
||||
# filename at ingest should be the path relative to root
|
||||
rel1 = str(p1.relative_to(root_base)).replace("\\", "/")
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as g1:
|
||||
d1 = await g1.json()
|
||||
assert g1.status == 200, d1
|
||||
fn1 = d1["user_metadata"].get("filename")
|
||||
assert fn1 == rel1
|
||||
|
||||
# Retarget: copy to x/<name2>, remove old, then sync+scan
|
||||
p2 = root_base / "unit-tests" / scope / "x" / name2
|
||||
p2.parent.mkdir(parents=True, exist_ok=True)
|
||||
p2.write_bytes(data)
|
||||
if p1.exists():
|
||||
p1.unlink()
|
||||
|
||||
await trigger_sync_seed_assets(http, api_base) # seed the new path
|
||||
await run_scan_and_wait(root) # verify/hash and reconcile
|
||||
|
||||
# filename should now point at x/<name2>
|
||||
rel2 = str(p2.relative_to(root_base)).replace("\\", "/")
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as g2:
|
||||
d2 = await g2.json()
|
||||
assert g2.status == 200, d2
|
||||
fn2 = d2["user_metadata"].get("filename")
|
||||
assert fn2 == rel2
168
tests-assets/test_downloads.py
Normal file
@ -0,0 +1,168 @@
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from conftest import get_asset_filename, trigger_sync_seed_assets
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_attachment_and_inline(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
# default attachment
|
||||
async with http.get(f"{api_base}/api/assets/{aid}/content") as r1:
|
||||
data = await r1.read()
|
||||
assert r1.status == 200
|
||||
cd = r1.headers.get("Content-Disposition", "")
|
||||
assert "attachment" in cd
|
||||
assert data and len(data) == 4096
|
||||
|
||||
# inline requested
|
||||
async with http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline") as r2:
|
||||
await r2.read()
|
||||
assert r2.status == 200
|
||||
cd2 = r2.headers.get("Content-Disposition", "")
|
||||
assert "inline" in cd2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_download_chooses_existing_state_and_updates_access_time(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""
|
||||
Hashed asset with two state paths: if the first one disappears,
|
||||
GET /content still serves from the remaining path and bumps last_access_time.
|
||||
"""
|
||||
scope = f"dl-first-{uuid.uuid4().hex[:6]}"
|
||||
name = "first_existing_state.bin"
|
||||
data = make_asset_bytes(name, 3072)
|
||||
|
||||
# Upload -> path1
|
||||
a = await asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||
aid = a["id"]
|
||||
|
||||
base = comfy_tmp_base_dir / root / "unit-tests" / scope
|
||||
path1 = base / get_asset_filename(a["asset_hash"], ".bin")
|
||||
assert path1.exists()
|
||||
|
||||
# Seed path2 by copying, then scan to dedupe into a second state
|
||||
path2 = base / "alt" / name
|
||||
path2.parent.mkdir(parents=True, exist_ok=True)
|
||||
path2.write_bytes(data)
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
await run_scan_and_wait(root)
|
||||
|
||||
# Remove path1 so server must fall back to path2
|
||||
path1.unlink()
|
||||
|
||||
# last_access_time before
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg0:
|
||||
d0 = await rg0.json()
|
||||
assert rg0.status == 200, d0
|
||||
ts0 = d0.get("last_access_time")
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
async with http.get(f"{api_base}/api/assets/{aid}/content") as r:
|
||||
blob = await r.read()
|
||||
assert r.status == 200
|
||||
assert blob == data # must serve from the surviving state (same bytes)
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg1:
|
||||
d1 = await rg1.json()
|
||||
assert rg1.status == 200, d1
|
||||
ts1 = d1.get("last_access_time")
|
||||
|
||||
def _parse_iso8601(s: Optional[str]) -> Optional[float]:
|
||||
if not s:
|
||||
return None
|
||||
s = s[:-1] if s.endswith("Z") else s
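# datetime.fromisoformat() does not accept a trailing 'Z' before Python 3.11, so strip it first.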
|
||||
return datetime.fromisoformat(s).timestamp()
|
||||
|
||||
t0 = _parse_iso8601(ts0)
|
||||
t1 = _parse_iso8601(ts1)
|
||||
assert t1 is not None
|
||||
if t0 is not None:
|
||||
assert t1 > t0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True)
|
||||
async def test_download_missing_file_returns_404(
|
||||
http: aiohttp.ClientSession, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict
|
||||
):
|
||||
# Remove the underlying file then attempt download.
|
||||
# We initialize the fixture without additional tags so that the exact asset file path is known.
|
||||
try:
|
||||
aid = seeded_asset["id"]
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
|
||||
detail = await rg.json()
|
||||
assert rg.status == 200
|
||||
asset_filename = get_asset_filename(detail["asset_hash"], ".safetensors")
|
||||
abs_path = comfy_tmp_base_dir / "models" / "checkpoints" / asset_filename
|
||||
assert abs_path.exists()
|
||||
abs_path.unlink()
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{aid}/content") as r2:
|
||||
assert r2.status == 404
|
||||
body = await r2.json()
|
||||
assert body["error"]["code"] == "FILE_NOT_FOUND"
|
||||
finally:
|
||||
# The asset was created without the "unit-tests" tag (see `autoclean_unit_test_assets`), so it must be cleaned up manually.
|
||||
async with http.delete(f"{api_base}/api/assets/{aid}") as dr:
|
||||
await dr.read()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_download_404_if_all_states_missing(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""Multi-state asset: after the last remaining on-disk file is removed, download must return 404."""
|
||||
scope = f"dl-404-{uuid.uuid4().hex[:6]}"
|
||||
name = "missing_all_states.bin"
|
||||
data = make_asset_bytes(name, 2048)
|
||||
|
||||
# Upload -> path1
|
||||
a = await asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||
aid = a["id"]
|
||||
|
||||
base = comfy_tmp_base_dir / root / "unit-tests" / scope
|
||||
p1 = base / get_asset_filename(a["asset_hash"], ".bin")
|
||||
assert p1.exists()
|
||||
|
||||
# Seed a second state and dedupe
|
||||
p2 = base / "copy" / name
|
||||
p2.parent.mkdir(parents=True, exist_ok=True)
|
||||
p2.write_bytes(data)
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
await run_scan_and_wait(root)
|
||||
|
||||
# Remove first file -> download should still work via the second state
|
||||
p1.unlink()
|
||||
async with http.get(f"{api_base}/api/assets/{aid}/content") as ok1:
|
||||
b1 = await ok1.read()
|
||||
assert ok1.status == 200 and b1 == data
|
||||
|
||||
# Remove the last file -> download must 404
|
||||
p2.unlink()
|
||||
async with http.get(f"{api_base}/api/assets/{aid}/content") as r2:
|
||||
body = await r2.json()
|
||||
assert r2.status == 404
|
||||
assert body["error"]["code"] == "FILE_NOT_FOUND"
337
tests-assets/test_list_filter.py
Normal file
@ -0,0 +1,337 @@
import asyncio
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_paging_and_sort(http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes):
|
||||
names = ["a1_u.safetensors", "a2_u.safetensors", "a3_u.safetensors"]
|
||||
for n in names:
|
||||
await asset_factory(
|
||||
n,
|
||||
["models", "checkpoints", "unit-tests", "paging"],
|
||||
{"epoch": 1},
|
||||
make_asset_bytes(n, size=2048),
|
||||
)
|
||||
|
||||
# name ascending for stable order
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "0"},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
got1 = [a["name"] for a in b1["assets"]]
|
||||
assert got1 == sorted(names)[:2]
|
||||
assert b1["has_more"] is True
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "2"},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
got2 = [a["name"] for a in b2["assets"]]
|
||||
assert got2 == sorted(names)[2:]
|
||||
assert b2["has_more"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_include_exclude_and_name_contains(http: aiohttp.ClientSession, api_base: str, asset_factory):
|
||||
a = await asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024)
|
||||
b = await asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024)
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,alpha", "exclude_tags": "beta", "limit": "50"},
|
||||
) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 200
|
||||
names = [x["name"] for x in body["assets"]]
|
||||
assert a["name"] in names
|
||||
assert b["name"] not in names
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests", "name_contains": "inc_"},
|
||||
) as r2:
|
||||
body2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
names2 = [x["name"] for x in body2["assets"]]
|
||||
assert a["name"] in names2
|
||||
assert b["name"] in names2
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "non-existing-tag"},
|
||||
) as r2:
|
||||
body3 = await r2.json()
|
||||
assert r2.status == 200
|
||||
assert not body3["assets"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-size"]
|
||||
n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors"
|
||||
await asset_factory(n1, t, {}, make_asset_bytes(n1, 1024))
|
||||
await asset_factory(n2, t, {}, make_asset_bytes(n2, 2048))
|
||||
await asset_factory(n3, t, {}, make_asset_bytes(n3, 3072))
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "asc"},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
names = [a["name"] for a in b1["assets"]]
|
||||
assert names[:3] == [n1, n2, n3]
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "desc"},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
names2 = [a["name"] for a in b2["assets"]]
|
||||
assert names2[:3] == [n3, n2, n1]
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-upd"]
|
||||
a1 = await asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200))
|
||||
a2 = await asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200))
|
||||
|
||||
# Rename the second asset to bump updated_at
|
||||
async with http.put(f"{api_base}/api/assets/{a2['id']}", json={"name": "upd_b_renamed.safetensors"}) as rp:
|
||||
upd = await rp.json()
|
||||
assert rp.status == 200, upd
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-upd", "sort": "updated_at", "order": "desc"},
|
||||
) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 200
|
||||
names = [x["name"] for x in body["assets"]]
|
||||
assert names[0] == "upd_b_renamed.safetensors"
|
||||
assert a1["name"] in names
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-access"]
|
||||
await asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100))
|
||||
await asyncio.sleep(0.02)
|
||||
a2 = await asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100))
|
||||
|
||||
# Touch last_access_time of b by downloading its content
|
||||
await asyncio.sleep(0.02)
|
||||
async with http.get(f"{api_base}/api/assets/{a2['id']}/content") as dl:
|
||||
assert dl.status == 200
|
||||
await dl.read()
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-access", "sort": "last_access_time", "order": "desc"},
|
||||
) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 200
|
||||
names = [x["name"] for x in body["assets"]]
|
||||
assert names[0] == a2["name"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-include"]
|
||||
a = await asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva"))
|
||||
await asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb"))
|
||||
|
||||
# CSV + case-insensitive
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
names1 = [x["name"] for x in b1["assets"]]
|
||||
assert a["name"] in names1
|
||||
assert not any("beta" in x for x in names1)
|
||||
|
||||
# Repeated query params for include_tags
|
||||
params_multi = [
|
||||
("include_tags", "unit-tests"),
|
||||
("include_tags", "lf-include"),
|
||||
("include_tags", "alpha"),
|
||||
]
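# aiohttp accepts a sequence of (key, value) pairs and encodes them as repeated query parameters.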
|
||||
async with http.get(api_base + "/api/assets", params=params_multi) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
names2 = [x["name"] for x in b2["assets"]]
|
||||
assert a["name"] in names2
|
||||
assert not any("beta" in x for x in names2)
|
||||
|
||||
# Duplicates and spaces in CSV
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": " unit-tests , lf-include , alpha , alpha "},
|
||||
) as r3:
|
||||
b3 = await r3.json()
|
||||
assert r3.status == 200
|
||||
names3 = [x["name"] for x in b3["assets"]]
|
||||
assert a["name"] in names3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-exclude"]
|
||||
a = await asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900))
|
||||
await asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900))
|
||||
|
||||
# Exclude uppercase should work
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
names1 = [x["name"] for x in b1["assets"]]
|
||||
assert a["name"] in names1
|
||||
# Repeated excludes with duplicates
|
||||
params_multi = [
|
||||
("include_tags", "unit-tests"),
|
||||
("include_tags", "lf-exclude"),
|
||||
("exclude_tags", "beta"),
|
||||
("exclude_tags", "beta"),
|
||||
]
|
||||
async with http.get(api_base + "/api/assets", params=params_multi) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
names2 = [x["name"] for x in b2["assets"]]
|
||||
assert all("beta" not in x for x in names2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-name"]
|
||||
a1 = await asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800))
|
||||
a2 = await asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800))
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-name", "name_contains": "casemix"},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
names1 = [x["name"] for x in b1["assets"]]
|
||||
assert a1["name"] in names1
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-name", "name_contains": ".SAFE"},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
names2 = [x["name"] for x in b2["assets"]]
|
||||
assert a1["name"] in names2
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-name", "name_contains": "case-"},
|
||||
) as r3:
|
||||
b3 = await r3.json()
|
||||
assert r3.status == 200
|
||||
names3 = [x["name"] for x in b3["assets"]]
|
||||
assert a2["name"] in names3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"]
|
||||
await asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600))
|
||||
await asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600))
|
||||
await asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600))
|
||||
|
||||
# Offset far beyond total
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-pagelimits", "limit": "2", "offset": "10"},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
assert not b1["assets"]
|
||||
assert b1["has_more"] is False
|
||||
|
||||
# Boundary large limit (<=500 is valid)
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,lf-pagelimits", "limit": "500"},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
assert len(b2["assets"]) == 3
|
||||
assert b2["has_more"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base):
|
||||
async with http.get(api_base + "/api/assets", params={"offset": "-1"}) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 400
|
||||
assert b1["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
async with http.get(api_base + "/api/assets", params={"limit": "abc"}) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 400
|
||||
assert b2["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_invalid_query_rejected(http: aiohttp.ClientSession, api_base: str):
|
||||
# limit too small
|
||||
async with http.get(api_base + "/api/assets", params={"limit": "0"}) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 400
|
||||
assert b1["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
# bad metadata JSON
|
||||
async with http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 400
|
||||
assert b2["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_assets_name_contains_literal_underscore(
|
||||
http,
|
||||
api_base,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""'name_contains' must treat '_' literally, not as a SQL wildcard.
|
||||
We create:
|
||||
- foo_bar.safetensors (should match)
|
||||
- fooxbar.safetensors (must NOT match if '_' is escaped)
|
||||
- foobar.safetensors (must NOT match)
|
||||
"""
|
||||
scope = f"lf-underscore-{uuid.uuid4().hex[:6]}"
|
||||
tags = ["models", "checkpoints", "unit-tests", scope]
|
||||
|
||||
a = await asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700))
|
||||
b = await asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700))
|
||||
c = await asset_factory("foobar.safetensors", tags, {}, make_asset_bytes("c", 700))
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}", "name_contains": "foo_bar"},
|
||||
) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 200, body
|
||||
names = [x["name"] for x in body["assets"]]
|
||||
assert a["name"] in names, f"Expected literal underscore match to include {a['name']}"
|
||||
assert b["name"] not in names, "Underscore must be escaped — should not match 'fooxbar'"
|
||||
assert c["name"] not in names, "Underscore must be escaped — should not match 'foobar'"
|
||||
assert body["total"] == 1
387
tests-assets/test_metadata_filters.py
Normal file
@ -0,0 +1,387 @@
import json
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_and_across_keys_and_types(
|
||||
http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes
|
||||
):
|
||||
name = "mf_and_mix.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-and"]
|
||||
meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
|
||||
await asset_factory(name, tags, meta, make_asset_bytes(name, 4096))
|
||||
|
||||
# All keys must match (AND semantics)
|
||||
f_ok = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-and",
|
||||
"metadata_filter": json.dumps(f_ok),
|
||||
},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
names = [a["name"] for a in b1["assets"]]
|
||||
assert name in names
|
||||
|
||||
# One key mismatched -> no result
|
||||
f_bad = {"purpose": "mix", "epoch": 2, "active": True}
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-and",
|
||||
"metadata_filter": json.dumps(f_bad),
|
||||
},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
assert not b2["assets"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_types.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-types"]
|
||||
meta = {"epoch": 1, "active": True}
|
||||
await asset_factory(name, tags, meta, make_asset_bytes(name))
|
||||
|
||||
# int filter matches numeric
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-types",
|
||||
"metadata_filter": json.dumps({"epoch": 1}),
|
||||
},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# string "1" must NOT match numeric 1
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-types",
|
||||
"metadata_filter": json.dumps({"epoch": "1"}),
|
||||
},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200 and not b2["assets"]
|
||||
|
||||
# bool True matches, string "true" must NOT match
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-types",
|
||||
"metadata_filter": json.dumps({"active": True}),
|
||||
},
|
||||
) as r3:
|
||||
b3 = await r3.json()
|
||||
assert r3.status == 200 and any(a["name"] == name for a in b3["assets"])
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-types",
|
||||
"metadata_filter": json.dumps({"active": "true"}),
|
||||
},
|
||||
) as r4:
|
||||
b4 = await r4.json()
|
||||
assert r4.status == 200 and not b4["assets"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_list_scalars.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-list"]
|
||||
meta = {"flags": ["red", "green"]}
|
||||
await asset_factory(name, tags, meta, make_asset_bytes(name, 3000))
|
||||
|
||||
# Any-of should match because "green" is present
|
||||
filt_ok = {"flags": ["blue", "green"]}
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_ok)},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# None of provided flags present -> no match
|
||||
filt_miss = {"flags": ["blue", "yellow"]}
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_miss)},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200 and not b2["assets"]
|
||||
|
||||
# Duplicates in list should not break matching
|
||||
filt_dup = {"flags": ["green", "green", "green"]}
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_dup)},
|
||||
) as r3:
|
||||
b3 = await r3.json()
|
||||
assert r3.status == 200 and any(a["name"] == name for a in b3["assets"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_none_semantics_missing_or_null_and_any_of_with_none(
|
||||
http, api_base, asset_factory, make_asset_bytes
|
||||
):
|
||||
# a1: key missing; a2: explicit null; a3: concrete value
|
||||
t = ["models", "checkpoints", "unit-tests", "mf-none"]
|
||||
a1 = await asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1"))
|
||||
a2 = await asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2"))
|
||||
a3 = await asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3"))
|
||||
|
||||
# Filter {maybe: None} must match a1 and a2, not a3
|
||||
filt = {"maybe": None}
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt), "sort": "name"},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
got = [a["name"] for a in b1["assets"]]
|
||||
assert a1["name"] in got and a2["name"] in got and a3["name"] not in got
|
||||
|
||||
# Any-of with None should include missing/null plus value matches
|
||||
filt_any = {"maybe": [None, "x"]}
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt_any), "sort": "name"},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
got2 = [a["name"] for a in b2["assets"]]
|
||||
assert a1["name"] in got2 and a2["name"] in got2 and a3["name"] in got2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_nested_json.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-nested"]
|
||||
cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}}
|
||||
await asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200))
|
||||
|
||||
# Exact JSON object equality (same structure)
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-nested",
|
||||
"metadata_filter": json.dumps({"config": cfg}),
|
||||
},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# Different JSON object should not match
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-nested",
|
||||
"metadata_filter": json.dumps({"config": {"optimizer": "sgd"}}),
|
||||
},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200 and not b2["assets"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_list_objects.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-objlist"]
|
||||
transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}]
|
||||
await asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048))
|
||||
|
||||
# Any-of for list of objects should match when one element equals the filter object
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-objlist",
|
||||
"metadata_filter": json.dumps({"transforms": {"type": "flip", "p": 0.5}}),
|
||||
},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# Non-matching object -> no match
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,mf-objlist",
|
||||
"metadata_filter": json.dumps({"transforms": {"type": "rotate", "deg": 90}}),
|
||||
},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200 and not b2["assets"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_keys_unicode.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-keys"]
|
||||
meta = {
|
||||
"weird.key": "v1",
|
||||
"path/like": 7,
|
||||
"with:colon": True,
|
||||
"ключ": "значение",
|
||||
"emoji": "🐍",
|
||||
}
|
||||
await asset_factory(name, tags, meta, make_asset_bytes(name, 1500))
|
||||
|
||||
# Match all the special keys
|
||||
filt = {"weird.key": "v1", "path/like": 7, "with:colon": True, "emoji": "🐍"}
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps(filt)},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# Unicode key match
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps({"ключ": "значение"})},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200 and any(a["name"] == name for a in b2["assets"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes):
|
||||
t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"]
|
||||
a0 = await asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025))
|
||||
a1 = await asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026))
|
||||
|
||||
# count == 0 must match only a0
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"count": 0})},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
names1 = [a["name"] for a in b1["assets"]]
|
||||
assert a0["name"] in names1 and a1["name"] not in names1
|
||||
|
||||
# Any-of list of booleans: True matches second asset
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"choices": True})},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200 and any(a["name"] == a1["name"] for a in b2["assets"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes):
|
||||
name = "mf_mixed_list.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "mf-mixed"]
|
||||
meta = {"mix": ["1", 1, True, None]}
|
||||
await asset_factory(name, tags, meta, make_asset_bytes(name, 1999))
|
||||
|
||||
# Should match because 1 is present
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": [2, 1]})},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200 and any(a["name"] == name for a in b1["assets"])
|
||||
|
||||
# Should NOT match for False
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": False})},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200 and not b2["assets"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes):
|
||||
# Use a unique scope tag to avoid interference
|
||||
t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"]
|
||||
x = await asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua"))
|
||||
y = await asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub"))
|
||||
|
||||
# Filtering by unknown key with None should return both (missing key OR null)
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": None})},
|
||||
) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
names = {a["name"] for a in b1["assets"]}
|
||||
assert x["name"] in names and y["name"] in names
|
||||
|
||||
# Filtering by unknown key with concrete value should return none
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": "x"})},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200 and not b2["assets"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_factory, make_asset_bytes):
|
||||
# alpha matches epoch=1; beta has epoch=2
|
||||
a = await asset_factory(
|
||||
"mf_tag_alpha.safetensors",
|
||||
["models", "checkpoints", "unit-tests", "mf-tag", "alpha"],
|
||||
{"epoch": 1},
|
||||
make_asset_bytes("alpha"),
|
||||
)
|
||||
b = await asset_factory(
|
||||
"mf_tag_beta.safetensors",
|
||||
["models", "checkpoints", "unit-tests", "mf-tag", "beta"],
|
||||
{"epoch": 2},
|
||||
make_asset_bytes("beta"),
|
||||
)
|
||||
|
||||
params = {
|
||||
"include_tags": "unit-tests,mf-tag,alpha",
|
||||
"exclude_tags": "beta",
|
||||
"name_contains": "mf_tag_",
|
||||
"metadata_filter": json.dumps({"epoch": 1}),
|
||||
}
|
||||
async with http.get(api_base + "/api/assets", params=params) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 200
|
||||
names = [x["name"] for x in body["assets"]]
|
||||
assert a["name"] in names
|
||||
assert b["name"] not in names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes):
|
||||
# Three assets in same scope with different sizes and a common filter key
|
||||
t = ["models", "checkpoints", "unit-tests", "mf-sort"]
|
||||
n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors"
|
||||
await asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024))
|
||||
await asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048))
|
||||
await asset_factory(n3, t, {"group": "g"}, make_asset_bytes(n3, 3072))
|
||||
|
||||
# Sort by size ascending with paging
|
||||
q = {
|
||||
"include_tags": "unit-tests,mf-sort",
|
||||
"metadata_filter": json.dumps({"group": "g"}),
|
||||
"sort": "size", "order": "asc", "limit": "2",
|
||||
}
|
||||
async with http.get(api_base + "/api/assets", params=q) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
got1 = [a["name"] for a in b1["assets"]]
|
||||
assert got1 == [n1, n2]
|
||||
assert b1["has_more"] is True
|
||||
|
||||
q2 = {**q, "offset": "2"}
|
||||
async with http.get(api_base + "/api/assets", params=q2) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
got2 = [a["name"] for a in b2["assets"]]
|
||||
assert got2 == [n3]
|
||||
assert b2["has_more"] is False
|
||||
510
tests-assets/test_scans.py
Normal file
510
tests-assets/test_scans.py
Normal file
@ -0,0 +1,510 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from conftest import get_asset_filename, trigger_sync_seed_assets
|
||||
|
||||
|
||||
def _base_for(root: str, comfy_tmp_base_dir: Path) -> Path:
|
||||
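# Resolve the on-disk directory for a logical root ("input" or "output") under the test base dir.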
assert root in ("input", "output")
|
||||
return comfy_tmp_base_dir / root
|
||||
|
||||
|
||||
def _mkbytes(label: str, size: int) -> bytes:
|
||||
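# Deterministic pseudo-random payload: the same label and size always produce identical bytes.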
seed = sum(label.encode("utf-8")) % 251
|
||||
return bytes((i * 31 + seed) % 256 for i in range(size))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_scan_schedule_idempotent_while_running(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""Idempotent schedule while running."""
|
||||
scope = f"idem-{uuid.uuid4().hex[:6]}"
|
||||
base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create several seed files (non-zero size) so the scan runs long enough
|
||||
for i in range(8):
|
||||
(base / f"f{i}.bin").write_bytes(_mkbytes(f"{scope}-{i}", 2 * 1024 * 1024)) # ~2 MiB each
|
||||
|
||||
# Seed -> states with hash=NULL
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Schedule once
|
||||
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 202, b1
|
||||
scans1 = {s["root"]: s for s in b1.get("scans", [])}
|
||||
s1 = scans1.get(root)
|
||||
assert s1 and s1["status"] in {"scheduled", "running"}
|
||||
sid1 = s1["scan_id"]
|
||||
|
||||
# Schedule again immediately — must return the same scan entry (no new worker)
|
||||
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 202, b2
|
||||
scans2 = {s["root"]: s for s in b2.get("scans", [])}
|
||||
s2 = scans2.get(root)
|
||||
assert s2 and s2["scan_id"] == sid1
|
||||
|
||||
# Filtered GET must show exactly one scan for this root
|
||||
async with http.get(api_base + "/api/assets/scan", params={"root": root}) as gs:
|
||||
bs = await gs.json()
|
||||
assert gs.status == 200, bs
|
||||
scans = bs.get("scans", [])
|
||||
assert len(scans) == 1 and scans[0]["scan_id"] == sid1
|
||||
|
||||
# Let it finish to avoid cross-test interference
|
||||
await run_scan_and_wait(root)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_status_filter_by_root_and_file_errors(
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
run_scan_and_wait,
|
||||
asset_factory,
|
||||
):
|
||||
"""Filtering get scan status by root (schedule for both input and output) + file_errors presence."""
|
||||
# Create one hashed asset in input under a dir we will chmod to 0o000 to force a PermissionError in the reconcile stage
|
||||
in_scope = f"filter-in-{uuid.uuid4().hex[:6]}"
|
||||
protected_dir = _base_for("input", comfy_tmp_base_dir) / "unit-tests" / in_scope / "deny"
|
||||
protected_dir.mkdir(parents=True, exist_ok=True)
|
||||
name_in = "protected.bin"
|
||||
|
||||
data = b"A" * 4096
|
||||
await asset_factory(name_in, ["input", "unit-tests", in_scope, "deny"], {}, data)
|
||||
try:
|
||||
os.chmod(protected_dir, 0o000)
|
||||
|
||||
# Also schedule a scan for output root (no errors there)
|
||||
out_scope = f"filter-out-{uuid.uuid4().hex[:6]}"
|
||||
out_dir = _base_for("output", comfy_tmp_base_dir) / "unit-tests" / out_scope
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
(out_dir / "ok.bin").write_bytes(b"B" * 1024)
|
||||
await trigger_sync_seed_assets(http, api_base) # seed output file
|
||||
|
||||
# Schedule both roots
|
||||
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": ["input"]}) as r_in:
|
||||
assert r_in.status == 202
|
||||
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": ["output"]}) as r_out:
|
||||
assert r_out.status == 202
|
||||
|
||||
# Wait both to complete, input last (we want its errors)
|
||||
await run_scan_and_wait("output")
|
||||
await run_scan_and_wait("input")
|
||||
|
||||
# Filter by root=input: only input scan listed and must have file_errors
|
||||
async with http.get(api_base + "/api/assets/scan", params={"root": "input"}) as gs:
|
||||
body = await gs.json()
|
||||
assert gs.status == 200, body
|
||||
scans = body.get("scans", [])
|
||||
assert len(scans) == 1
|
||||
errs = scans[0].get("file_errors", [])
|
||||
# Must contain at least one error with a message
|
||||
assert errs and any(e.get("message") for e in errs)
|
||||
finally:
|
||||
# Restore perms so cleanup can remove files/dirs
|
||||
try:
|
||||
os.chmod(protected_dir, 0o755)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
@pytest.mark.skipif(os.name == "nt", reason="Permission-based file_errors are unreliable on Windows")
|
||||
async def test_scan_records_file_errors_permission_denied(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""file_errors recording (permission denied) for input/output"""
|
||||
scope = f"errs-{uuid.uuid4().hex[:6]}"
|
||||
deny_dir = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / "deny"
|
||||
deny_dir.mkdir(parents=True, exist_ok=True)
|
||||
name = "deny.bin"
|
||||
|
||||
a1 = await asset_factory(name, [root, "unit-tests", scope, "deny"], {}, b"X" * 2048)
|
||||
asset_filename = get_asset_filename(a1["asset_hash"], ".bin")
|
||||
try:
|
||||
os.chmod(deny_dir, 0o000)
|
||||
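# With the directory unreadable, hashing this file during the scan should fail and be recorded in file_errors.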
async with http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}) as r:
|
||||
assert r.status == 202
|
||||
await run_scan_and_wait(root)
|
||||
|
||||
async with http.get(api_base + "/api/assets/scan", params={"root": root}) as gs:
|
||||
body = await gs.json()
|
||||
assert gs.status == 200, body
|
||||
scans = body.get("scans", [])
|
||||
assert len(scans) == 1
|
||||
errs = scans[0].get("file_errors", [])
|
||||
# Should contain at least one PermissionError-like record
|
||||
assert errs
|
||||
assert any(e.get("path", "").endswith(asset_filename) and e.get("message") for e in errs)
|
||||
finally:
|
||||
try:
|
||||
os.chmod(deny_dir, 0o755)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_missing_tag_created_and_visible_in_tags(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
):
|
||||
"""Missing tag appears in tags list and increments count (input/output)"""
|
||||
# Baseline count of 'missing' tag (may be absent)
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "1000"}) as r0:
|
||||
t0 = await r0.json()
|
||||
assert r0.status == 200, t0
|
||||
byname = {t["name"]: t for t in t0.get("tags", [])}
|
||||
old_count = int(byname.get("missing", {}).get("count", 0))
|
||||
|
||||
scope = f"miss-{uuid.uuid4().hex[:6]}"
|
||||
name = "missing_me.bin"
|
||||
created = await asset_factory(name, [root, "unit-tests", scope], {}, b"Y" * 4096)
|
||||
|
||||
# Remove the only file and trigger fast pass
|
||||
p = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(created["asset_hash"], ".bin")
|
||||
assert p.exists()
|
||||
p.unlink()
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Asset has 'missing' tag
|
||||
async with http.get(f"{api_base}/api/assets/{created['id']}") as g1:
|
||||
d1 = await g1.json()
|
||||
assert g1.status == 200, d1
|
||||
assert "missing" in set(d1.get("tags", []))
|
||||
|
||||
# Tag list now contains 'missing' with increased count
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}) as r1:
|
||||
t1 = await r1.json()
|
||||
assert r1.status == 200, t1
|
||||
byname1 = {t["name"]: t for t in t1.get("tags", [])}
|
||||
assert "missing" in byname1
|
||||
assert int(byname1["missing"]["count"]) >= old_count + 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_missing_reapplies_after_manual_removal(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
):
|
||||
"""Manual removal of 'missing' does not block automatic re-apply (input/output)"""
|
||||
scope = f"reapply-{uuid.uuid4().hex[:6]}"
|
||||
name = "reapply.bin"
|
||||
created = await asset_factory(name, [root, "unit-tests", scope], {}, b"Z" * 1024)
|
||||
|
||||
# Make it missing
|
||||
p = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(created["asset_hash"], ".bin")
|
||||
p.unlink()
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Remove the 'missing' tag manually
|
||||
async with http.delete(f"{api_base}/api/assets/{created['id']}/tags", json={"tags": ["missing"]}) as rdel:
|
||||
b = await rdel.json()
|
||||
assert rdel.status == 200, b
|
||||
assert "missing" in set(b.get("removed", []))
|
||||
|
||||
# Next sync must re-add it
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
async with http.get(f"{api_base}/api/assets/{created['id']}") as g2:
|
||||
d2 = await g2.json()
|
||||
assert g2.status == 200, d2
|
||||
assert "missing" in set(d2.get("tags", []))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_delete_one_asset_info_of_missing_asset_keeps_identity(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
):
|
||||
"""Delete one AssetInfo of a missing asset while another exists (input/output)"""
|
||||
scope = f"twoinfos-{uuid.uuid4().hex[:6]}"
|
||||
name = "twoinfos.bin"
|
||||
a1 = await asset_factory(name, [root, "unit-tests", scope], {}, b"W" * 2048)
|
||||
|
||||
# Second AssetInfo for the same content under the same root (different name to avoid collision)
|
||||
a2 = await asset_factory("copy_" + name, [root, "unit-tests", scope], {}, b"W" * 2048)
|
||||
|
||||
# Remove the first one's file (both point to the same Asset, but we know the on-disk path for a1)
|
||||
p1 = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(a1["asset_hash"], ".bin")
|
||||
p1.unlink()
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# Both infos should be marked missing
|
||||
async with http.get(f"{api_base}/api/assets/{a1['id']}") as g1:
|
||||
d1 = await g1.json()
|
||||
assert "missing" in set(d1.get("tags", []))
|
||||
async with http.get(f"{api_base}/api/assets/{a2['id']}") as g2:
|
||||
d2 = await g2.json()
|
||||
assert "missing" in set(d2.get("tags", []))
|
||||
|
||||
# Delete one info
|
||||
async with http.delete(f"{api_base}/api/assets/{a1['id']}") as rd:
|
||||
assert rd.status == 204
|
||||
|
||||
# Asset identity still exists (by hash)
|
||||
h = a1["asset_hash"]
|
||||
async with http.head(f"{api_base}/api/assets/hash/{h}") as rh:
|
||||
assert rh.status == 200
|
||||
|
||||
# Remaining info still reflects 'missing'
|
||||
async with http.get(f"{api_base}/api/assets/{a2['id']}") as g3:
|
||||
d3 = await g3.json()
|
||||
assert g3.status == 200 and "missing" in set(d3.get("tags", []))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("keep_root", ["input", "output"])
|
||||
async def test_delete_last_asset_info_false_keeps_asset_and_states_multiroot(
|
||||
keep_root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
make_asset_bytes,
|
||||
asset_factory,
|
||||
):
|
||||
"""Delete last AssetInfo with delete_content_if_orphan=false keeps asset and the underlying on-disk content."""
|
||||
other_root = "output" if keep_root == "input" else "input"
|
||||
scope = f"delfalse-{uuid.uuid4().hex[:6]}"
|
||||
data = make_asset_bytes(scope, 3072)
|
||||
|
||||
# First upload creates the physical file
|
||||
a1 = await asset_factory("keep1.bin", [keep_root, "unit-tests", scope], {}, data)
|
||||
# Second upload (other root) is deduped to the same content; no new file on disk
|
||||
a2 = await asset_factory("keep2.bin", [other_root, "unit-tests", scope], {}, data)
|
||||
|
||||
h = a1["asset_hash"]
|
||||
p1 = _base_for(keep_root, comfy_tmp_base_dir) / "unit-tests" / scope / get_asset_filename(h, ".bin")
|
||||
|
||||
# De-dup semantics: only the first physical file exists
|
||||
assert p1.exists(), "Expected the first physical file to exist"
|
||||
|
||||
# Delete both AssetInfos; keep content on the very last delete
|
||||
async with http.delete(f"{api_base}/api/assets/{a2['id']}") as rfirst:
|
||||
assert rfirst.status == 204
|
||||
async with http.delete(f"{api_base}/api/assets/{a1['id']}?delete_content=false") as rlast:
|
||||
assert rlast.status == 204
|
||||
|
||||
# Asset identity remains and physical content is still present
|
||||
async with http.head(f"{api_base}/api/assets/hash/{h}") as rh:
|
||||
assert rh.status == 200
|
||||
assert p1.exists(), "Content file should remain after keep-content delete"
|
||||
|
||||
# Cleanup: re-create a reference by hash, then delete it to purge the content
|
||||
payload = {
|
||||
"hash": h,
|
||||
"name": "cleanup.bin",
|
||||
"tags": [keep_root, "unit-tests", scope, "cleanup"],
|
||||
"user_metadata": {},
|
||||
}
|
||||
async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as rfh:
|
||||
ref = await rfh.json()
|
||||
assert rfh.status == 201, ref
|
||||
cid = ref["id"]
|
||||
async with http.delete(f"{api_base}/api/assets/{cid}") as rdel:
|
||||
assert rdel.status == 204
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_sync_seed_ignores_zero_byte_files(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
):
|
||||
scope = f"zero-{uuid.uuid4().hex[:6]}"
|
||||
base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
z = base / "empty.dat"
|
||||
z.write_bytes(b"") # zero bytes
|
||||
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
# No AssetInfo created for this zero-byte file
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests," + scope, "name_contains": "empty.dat"},
|
||||
) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 200, body
|
||||
assert not [a for a in body.get("assets", []) if a.get("name") == "empty.dat"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_sync_seed_idempotency(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
):
|
||||
scope = f"idemseed-{uuid.uuid4().hex[:6]}"
|
||||
base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
files = [f"f{i}.dat" for i in range(3)]
|
||||
for i, n in enumerate(files):
|
||||
(base / n).write_bytes(_mkbytes(n, 1500 + i * 10))
|
||||
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
async with http.get(api_base + "/api/assets", params={"include_tags": "unit-tests," + scope}) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200, b1
|
||||
c1 = len(b1.get("assets", []))
|
||||
|
||||
# Seed again -> count must stay the same
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
async with http.get(api_base + "/api/assets", params={"include_tags": "unit-tests," + scope}) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200, b2
|
||||
c2 = len(b2.get("assets", []))
|
||||
assert c1 == c2 == len(files)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_sync_seed_nested_dirs_produce_parent_tags(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
):
|
||||
scope = f"nest-{uuid.uuid4().hex[:6]}"
|
||||
# nested: unit-tests / scope / a / b / c / deep.txt
|
||||
deep_dir = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope / "a" / "b" / "c"
|
||||
deep_dir.mkdir(parents=True, exist_ok=True)
|
||||
(deep_dir / "deep.txt").write_bytes(scope.encode())
|
||||
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}", "name_contains": "deep.txt"},
|
||||
) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 200, body
|
||||
assets = body.get("assets", [])
|
||||
assert assets, "seeded asset not found"
|
||||
tags = set(assets[0].get("tags", []))
|
||||
# Must include all parent parts as tags + the root
|
||||
for must in {root, "unit-tests", scope, "a", "b", "c"}:
|
||||
assert must in tags
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_concurrent_seed_hashing_same_file_no_dupes(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""
|
||||
Create a single seed file, then schedule two scans back-to-back.
|
||||
Expect: no duplicate AssetInfos, a single hashed asset, and no scan failure.
|
||||
"""
|
||||
scope = f"conc-seed-{uuid.uuid4().hex[:6]}"
|
||||
name = "seed_concurrent.bin"
|
||||
|
||||
base = _base_for(root, comfy_tmp_base_dir) / "unit-tests" / scope
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
(base / name).write_bytes(b"Z" * 2048)
|
||||
|
||||
await trigger_sync_seed_assets(http, api_base)
|
||||
|
||||
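# Fire two schedule requests concurrently; both must be accepted without producing duplicate assets.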
s1, s2 = await asyncio.gather(
|
||||
http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}),
|
||||
http.post(api_base + "/api/assets/scan/schedule", json={"roots": [root]}),
|
||||
)
|
||||
await s1.read()
|
||||
await s2.read()
|
||||
assert s1.status in (200, 202)
|
||||
assert s2.status in (200, 202)
|
||||
|
||||
await run_scan_and_wait(root)
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
|
||||
) as r:
|
||||
b = await r.json()
|
||||
assert r.status == 200, b
|
||||
matches = [a for a in b.get("assets", []) if a.get("name") == name]
|
||||
assert len(matches) == 1
|
||||
assert matches[0].get("asset_hash"), "Seed should have been hashed into an Asset"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_cache_state_retarget_on_content_change_asset_info_stays(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
comfy_tmp_base_dir: Path,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
run_scan_and_wait,
|
||||
):
|
||||
"""
|
||||
Start with hashed H1 (AssetInfo A1). Replace file bytes on disk to become H2.
|
||||
After scan: AssetCacheState points to H2; A1 still references H1; downloading A1 -> 404.
|
||||
"""
|
||||
scope = f"retarget-{uuid.uuid4().hex[:6]}"
|
||||
name = "content_change.bin"
|
||||
d1 = make_asset_bytes("v1-" + scope, 2048)
|
||||
|
||||
a1 = await asset_factory(name, [root, "unit-tests", scope], {}, d1)
|
||||
aid = a1["id"]
|
||||
h1 = a1["asset_hash"]
|
||||
|
||||
p = comfy_tmp_base_dir / root / "unit-tests" / scope / get_asset_filename(a1["asset_hash"], ".bin")
|
||||
assert p.exists()
|
||||
|
||||
# Change the bytes in place to force a new content hash (H2)
|
||||
d2 = make_asset_bytes("v2-" + scope, 3072)
|
||||
p.write_bytes(d2)
|
||||
|
||||
# Scan to verify and retarget the state; reconcilers run after scan
|
||||
await run_scan_and_wait(root)
|
||||
|
||||
# AssetInfo still on the old content identity (H1)
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
|
||||
g = await rg.json()
|
||||
assert rg.status == 200, g
|
||||
assert g.get("asset_hash") == h1
|
||||
|
||||
# Download must fail until a state exists for H1 (the only one was retargeted to H2)
|
||||
async with http.get(f"{api_base}/api/assets/{aid}/content") as dl:
|
||||
body = await dl.json()
|
||||
assert dl.status == 404, body
|
||||
assert body["error"]["code"] == "FILE_NOT_FOUND"
|
||||
228
tests-assets/test_tags.py
Normal file
228
tests-assets/test_tags.py
Normal file
@ -0,0 +1,228 @@
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_present(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
# Include zero-usage tags by default
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1:
|
||||
body1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
names = [t["name"] for t in body1["tags"]]
|
||||
# A few system tags from migration should exist:
|
||||
assert "models" in names
|
||||
assert "checkpoints" in names
|
||||
|
||||
# Only tags that are in use, before this test cycle adds anything new
|
||||
async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2:
|
||||
body2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
# We already seeded one asset via fixture, so used tags must be non-empty
|
||||
used_names = [t["name"] for t in body2["tags"]]
|
||||
assert "models" in used_names
|
||||
assert "checkpoints" in used_names
|
||||
|
||||
# Prefix filter should refine the list
|
||||
async with http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}) as r3:
|
||||
b3 = await r3.json()
|
||||
assert r3.status == 200
|
||||
names3 = [t["name"] for t in b3["tags"]]
|
||||
assert "unit-tests" in names3
|
||||
assert "models" not in names3 # filtered out by prefix
|
||||
|
||||
# Ordering by name ascending should return names in sorted order
|
||||
async with http.get(api_base + "/api/tags", params={"include_zero": "false", "order": "name_asc"}) as r4:
|
||||
b4 = await r4.json()
|
||||
assert r4.status == 200
|
||||
names4 = [t["name"] for t in b4["tags"]]
|
||||
assert names4 == sorted(names4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_empty_usage(http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes):
|
||||
# Baseline: system tags exist when include_zero (default) is true
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "500"}) as r1:
|
||||
body1 = await r1.json()
|
||||
assert r1.status == 200
|
||||
names = [t["name"] for t in body1["tags"]]
|
||||
assert "models" in names and "checkpoints" in names
|
||||
|
||||
# Create a short-lived asset under input with a unique custom tag
|
||||
scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}"
|
||||
custom_tag = f"temp-{uuid.uuid4().hex[:8]}"
|
||||
name = "tag_seed.bin"
|
||||
_asset = await asset_factory(
|
||||
name,
|
||||
["input", "unit-tests", scope, custom_tag],
|
||||
{},
|
||||
make_asset_bytes(name, 512),
|
||||
)
|
||||
|
||||
# While the asset exists, the custom tag must appear when include_zero=false
|
||||
async with http.get(
|
||||
api_base + "/api/tags",
|
||||
params={"include_zero": "false", "prefix": custom_tag, "limit": "50"},
|
||||
) as r2:
|
||||
body2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
used_names = [t["name"] for t in body2["tags"]]
|
||||
assert custom_tag in used_names
|
||||
|
||||
# Delete the asset so the tag usage drops to zero
|
||||
async with http.delete(f"{api_base}/api/assets/{_asset['id']}") as rd:
|
||||
assert rd.status == 204
|
||||
|
||||
# Now the custom tag must not be returned when include_zero=false
|
||||
async with http.get(
|
||||
api_base + "/api/tags",
|
||||
params={"include_zero": "false", "prefix": custom_tag, "limit": "50"},
|
||||
) as r3:
|
||||
body3 = await r3.json()
|
||||
assert r3.status == 200
|
||||
names_after = [t["name"] for t in body3["tags"]]
|
||||
assert custom_tag not in names_after
|
||||
assert not names_after # filtered view should be empty now
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_and_remove_tags(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
# Add tags with duplicates and mixed case
|
||||
payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]}
|
||||
async with http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200, b1
|
||||
# normalized, deduplicated; 'unit-tests' was already present from the seed
|
||||
assert set(b1["added"]) == {"newtag", "beta"}
|
||||
assert set(b1["already_present"]) == {"unit-tests"}
|
||||
assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"]
|
||||
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
|
||||
g = await rg.json()
|
||||
assert rg.status == 200
|
||||
tags_now = set(g["tags"])
|
||||
assert {"newtag", "beta"}.issubset(tags_now)
|
||||
|
||||
# Remove a tag and a non-existent tag
|
||||
payload_del = {"tags": ["newtag", "does-not-exist"]}
|
||||
async with http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200
|
||||
assert set(b2["removed"]) == {"newtag"}
|
||||
assert set(b2["not_present"]) == {"does-not-exist"}
|
||||
|
||||
# Verify remaining tags after deletion
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg2:
|
||||
g2 = await rg2.json()
|
||||
assert rg2.status == 200
|
||||
tags_later = set(g2["tags"])
|
||||
assert "newtag" not in tags_later
|
||||
assert "beta" in tags_later # still present
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_list_order_and_prefix(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
h = seeded_asset["asset_hash"]
|
||||
|
||||
# Add both tags to the seeded asset (usage: orderaaa=1, orderbbb=1)
|
||||
async with http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": ["orderaaa", "orderbbb"]}) as r_add:
|
||||
add_body = await r_add.json()
|
||||
assert r_add.status == 200, add_body
|
||||
|
||||
# Create another AssetInfo from the same content but tagged ONLY with 'orderbbb'.
|
||||
payload = {
|
||||
"hash": h,
|
||||
"name": "order_only_bbb.safetensors",
|
||||
"tags": ["input", "unit-tests", "orderbbb"],
|
||||
"user_metadata": {},
|
||||
}
|
||||
async with http.post(f"{api_base}/api/assets/from-hash", json=payload) as r_copy:
|
||||
copy_body = await r_copy.json()
|
||||
assert r_copy.status == 201, copy_body
|
||||
|
||||
# 1) Default order (count_desc): 'orderbbb' should come before 'orderaaa'
|
||||
# because it has higher usage (2 vs 1).
|
||||
async with http.get(api_base + "/api/tags", params={"prefix": "order", "include_zero": "false"}) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 200, b1
|
||||
names1 = [t["name"] for t in b1["tags"]]
|
||||
counts1 = {t["name"]: t["count"] for t in b1["tags"]}
|
||||
# Both must be present within the prefix subset
|
||||
assert "orderaaa" in names1 and "orderbbb" in names1
|
||||
# Usage of 'orderbbb' must be >= 'orderaaa'; in our setup it's 2 vs 1
|
||||
assert counts1["orderbbb"] >= counts1["orderaaa"]
|
||||
# And with count_desc, 'orderbbb' appears earlier than 'orderaaa'
|
||||
assert names1.index("orderbbb") < names1.index("orderaaa")
|
||||
|
||||
# 2) name_asc: lexical order should flip the relative order
|
||||
async with http.get(
|
||||
api_base + "/api/tags",
|
||||
params={"prefix": "order", "include_zero": "false", "order": "name_asc"},
|
||||
) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200, b2
|
||||
names2 = [t["name"] for t in b2["tags"]]
|
||||
assert "orderaaa" in names2 and "orderbbb" in names2
|
||||
assert names2.index("orderaaa") < names2.index("orderbbb")
|
||||
|
||||
# 3) invalid limit rejected (existing negative case retained)
|
||||
async with http.get(api_base + "/api/tags", params={"limit": "1001"}) as r3:
|
||||
b3 = await r3.json()
|
||||
assert r3.status == 400
|
||||
assert b3["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_endpoints_invalid_bodies(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||
aid = seeded_asset["id"]
|
||||
|
||||
# Add with empty list
|
||||
async with http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": []}) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 400
|
||||
assert b1["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
# Remove with wrong type
|
||||
async with http.delete(f"{api_base}/api/assets/{aid}/tags", json={"tags": [123]}) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 400
|
||||
assert b2["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
# metadata_filter provided as JSON array should be rejected (must be object)
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"metadata_filter": json.dumps([{"x": 1}])},
|
||||
) as r3:
|
||||
b3 = await r3.json()
|
||||
assert r3.status == 400
|
||||
assert b3["error"]["code"] == "INVALID_QUERY"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_prefix_treats_underscore_literal(
|
||||
http,
|
||||
api_base,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""'prefix' for /api/tags must treat '_' literally, not as a wildcard."""
|
||||
base = f"pref_{uuid.uuid4().hex[:6]}"
|
||||
tag_ok = f"{base}_ok" # should match prefix=f"{base}_"
|
||||
tag_bad = f"{base}xok" # must NOT match if '_' is escaped
|
||||
scope = f"tags-underscore-{uuid.uuid4().hex[:6]}"
|
||||
|
||||
await asset_factory("t1.bin", ["input", "unit-tests", scope, tag_ok], {}, make_asset_bytes("t1", 512))
|
||||
await asset_factory("t2.bin", ["input", "unit-tests", scope, tag_bad], {}, make_asset_bytes("t2", 512))
|
||||
|
||||
async with http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": f"{base}_"}) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 200, body
|
||||
names = [t["name"] for t in body["tags"]]
|
||||
assert tag_ok in names, f"Expected {tag_ok} to be returned for prefix '{base}_'"
|
||||
assert tag_bad not in names, f"'{tag_bad}' must not match — '_' is not a wildcard"
|
||||
assert body["total"] == 1
|
||||
325
tests-assets/test_uploads.py
Normal file
325
tests-assets/test_uploads.py
Normal file
@ -0,0 +1,325 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_ok_duplicate_reference(http: aiohttp.ClientSession, api_base: str, make_asset_bytes):
|
||||
name = "dup_a.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests", "alpha"]
|
||||
meta = {"purpose": "dup"}
|
||||
data = make_asset_bytes(name)
|
||||
form1 = aiohttp.FormData()
|
||||
form1.add_field("file", data, filename=name, content_type="application/octet-stream")
|
||||
form1.add_field("tags", json.dumps(tags))
|
||||
form1.add_field("name", name)
|
||||
form1.add_field("user_metadata", json.dumps(meta))
|
||||
async with http.post(api_base + "/api/assets", data=form1) as r1:
|
||||
a1 = await r1.json()
|
||||
assert r1.status == 201, a1
|
||||
assert a1["created_new"] is True
|
||||
|
||||
# Second upload with the same data and name should return created_new == False and the same asset
|
||||
form2 = aiohttp.FormData()
|
||||
form2.add_field("file", data, filename=name, content_type="application/octet-stream")
|
||||
form2.add_field("tags", json.dumps(tags))
|
||||
form2.add_field("name", name)
|
||||
form2.add_field("user_metadata", json.dumps(meta))
|
||||
async with http.post(api_base + "/api/assets", data=form2) as r2:
|
||||
a2 = await r2.json()
|
||||
assert r2.status == 200, a2
|
||||
assert a2["created_new"] is False
|
||||
assert a2["asset_hash"] == a1["asset_hash"]
|
||||
assert a2["id"] == a1["id"] # old reference
|
||||
|
||||
# Third upload with the same data but a new name should return created_new == False and a new AssetInfo
|
||||
form3 = aiohttp.FormData()
|
||||
form3.add_field("file", data, filename=name, content_type="application/octet-stream")
|
||||
form3.add_field("tags", json.dumps(tags))
|
||||
form3.add_field("name", name + "_d")
|
||||
form3.add_field("user_metadata", json.dumps(meta))
|
||||
async with http.post(api_base + "/api/assets", data=form3) as r2:
|
||||
a3 = await r2.json()
|
||||
assert r2.status == 200, a3
|
||||
assert a3["created_new"] is False
|
||||
assert a3["asset_hash"] == a1["asset_hash"]
|
||||
assert a3["id"] != a1["id"] # old reference
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_fastpath_from_existing_hash_no_file(http: aiohttp.ClientSession, api_base: str):
|
||||
# Seed a small file first
|
||||
name = "fastpath_seed.safetensors"
|
||||
tags = ["models", "checkpoints", "unit-tests"]
|
||||
meta = {}
|
||||
form1 = aiohttp.FormData()
|
||||
form1.add_field("file", b"B" * 1024, filename=name, content_type="application/octet-stream")
|
||||
form1.add_field("tags", json.dumps(tags))
|
||||
form1.add_field("name", name)
|
||||
form1.add_field("user_metadata", json.dumps(meta))
|
||||
async with http.post(api_base + "/api/assets", data=form1) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 201, b1
|
||||
h = b1["asset_hash"]
|
||||
|
||||
# Now POST /api/assets with only hash and no file
|
||||
form2 = aiohttp.FormData(default_to_multipart=True)
|
||||
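# default_to_multipart keeps the body multipart even without a file part; non-multipart requests are rejected with 415.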
form2.add_field("hash", h)
|
||||
form2.add_field("tags", json.dumps(tags))
|
||||
form2.add_field("name", "fastpath_copy.safetensors")
|
||||
form2.add_field("user_metadata", json.dumps({"purpose": "copy"}))
|
||||
async with http.post(api_base + "/api/assets", data=form2) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200, b2 # fast path returns 200 with created_new == False
|
||||
assert b2["created_new"] is False
|
||||
assert b2["asset_hash"] == h
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_fastpath_with_known_hash_and_file(
|
||||
http: aiohttp.ClientSession, api_base: str
|
||||
):
|
||||
# Seed
|
||||
form1 = aiohttp.FormData()
|
||||
form1.add_field("file", b"C" * 128, filename="seed.safetensors", content_type="application/octet-stream")
|
||||
form1.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "fp"]))
|
||||
form1.add_field("name", "seed.safetensors")
|
||||
form1.add_field("user_metadata", json.dumps({}))
|
||||
async with http.post(api_base + "/api/assets", data=form1) as r1:
|
||||
b1 = await r1.json()
|
||||
assert r1.status == 201, b1
|
||||
h = b1["asset_hash"]
|
||||
|
||||
# Send both file and hash of existing content -> server must drain file and create from hash (200)
|
||||
form2 = aiohttp.FormData()
|
||||
form2.add_field("file", b"ignored" * 10, filename="ignored.bin", content_type="application/octet-stream")
|
||||
form2.add_field("hash", h)
|
||||
form2.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "fp"]))
|
||||
form2.add_field("name", "copy_from_hash.safetensors")
|
||||
form2.add_field("user_metadata", json.dumps({}))
|
||||
async with http.post(api_base + "/api/assets", data=form2) as r2:
|
||||
b2 = await r2.json()
|
||||
assert r2.status == 200, b2
|
||||
assert b2["created_new"] is False
|
||||
assert b2["asset_hash"] == h
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_multiple_tags_fields_are_merged(http: aiohttp.ClientSession, api_base: str):
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", b"B" * 256, filename="merge.safetensors", content_type="application/octet-stream")
|
||||
form.add_field("tags", "models,checkpoints") # CSV
|
||||
form.add_field("tags", json.dumps(["unit-tests", "alpha"])) # JSON array in second field
|
||||
form.add_field("name", "merge.safetensors")
|
||||
form.add_field("user_metadata", json.dumps({"u": 1}))
|
||||
async with http.post(api_base + "/api/assets", data=form) as r1:
|
||||
created = await r1.json()
|
||||
assert r1.status in (200, 201), created
|
||||
aid = created["id"]
|
||||
|
||||
# Verify all tags are present on the resource
|
||||
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
|
||||
detail = await rg.json()
|
||||
assert rg.status == 200, detail
|
||||
tags = set(detail["tags"])
|
||||
assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_concurrent_upload_identical_bytes_different_names(
|
||||
root: str,
|
||||
http: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""
|
||||
Two concurrent uploads of identical bytes but different names.
|
||||
Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True.
|
||||
"""
|
||||
scope = f"concupload-{uuid.uuid4().hex[:6]}"
|
||||
name1, name2 = "cu_a.bin", "cu_b.bin"
|
||||
data = make_asset_bytes("concurrent", 4096)
|
||||
tags = [root, "unit-tests", scope]
|
||||
|
||||
def _form(name: str) -> aiohttp.FormData:
|
||||
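# Build a multipart upload form carrying identical bytes under the given display name.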
f = aiohttp.FormData()
|
||||
f.add_field("file", data, filename=name, content_type="application/octet-stream")
|
||||
f.add_field("tags", json.dumps(tags))
|
||||
f.add_field("name", name)
|
||||
f.add_field("user_metadata", json.dumps({}))
|
||||
return f
|
||||
|
||||
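# Upload identical bytes under two different names concurrently; exactly one response should report created_new=True.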
r1, r2 = await asyncio.gather(
|
||||
http.post(api_base + "/api/assets", data=_form(name1)),
|
||||
http.post(api_base + "/api/assets", data=_form(name2)),
|
||||
)
|
||||
b1, b2 = await r1.json(), await r2.json()
|
||||
assert r1.status in (200, 201), b1
|
||||
assert r2.status in (200, 201), b2
|
||||
assert b1["asset_hash"] == b2["asset_hash"]
|
||||
assert b1["id"] != b2["id"]
|
||||
|
||||
created_flags = sorted([bool(b1.get("created_new")), bool(b2.get("created_new"))])
|
||||
assert created_flags == [False, True]
|
||||
|
||||
async with http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": f"unit-tests,{scope}", "sort": "name"},
|
||||
) as rl:
|
||||
bl = await rl.json()
|
||||
assert rl.status == 200, bl
|
||||
names = [a["name"] for a in bl.get("assets", [])]
|
||||
assert set([name1, name2]).issubset(names)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_from_hash_endpoint_404(http: aiohttp.ClientSession, api_base: str):
|
||||
payload = {
|
||||
"hash": "blake3:" + "0" * 64,
|
||||
"name": "nonexistent.bin",
|
||||
"tags": ["models", "checkpoints", "unit-tests"],
|
||||
}
|
||||
async with http.post(api_base + "/api/assets/from-hash", json=payload) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 404
|
||||
assert body["error"]["code"] == "ASSET_NOT_FOUND"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_zero_byte_rejected(http: aiohttp.ClientSession, api_base: str):
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", b"", filename="empty.safetensors", content_type="application/octet-stream")
|
||||
form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "edge"]))
|
||||
form.add_field("name", "empty.safetensors")
|
||||
form.add_field("user_metadata", json.dumps({}))
|
||||
async with http.post(api_base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 400
|
||||
assert body["error"]["code"] == "EMPTY_UPLOAD"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_invalid_root_tag_rejected(http: aiohttp.ClientSession, api_base: str):
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", b"A" * 64, filename="badroot.bin", content_type="application/octet-stream")
|
||||
form.add_field("tags", json.dumps(["not-a-root", "whatever"]))
|
||||
form.add_field("name", "badroot.bin")
|
||||
form.add_field("user_metadata", json.dumps({}))
|
||||
async with http.post(api_base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_user_metadata_must_be_json(http: aiohttp.ClientSession, api_base: str):
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", b"A" * 128, filename="badmeta.bin", content_type="application/octet-stream")
|
||||
form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "edge"]))
|
||||
form.add_field("name", "badmeta.bin")
|
||||
form.add_field("user_metadata", "{not json}") # invalid
|
||||
async with http.post(api_base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_requires_multipart(http: aiohttp.ClientSession, api_base: str):
|
||||
async with http.post(api_base + "/api/assets", json={"foo": "bar"}) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 415
|
||||
assert body["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_missing_file_and_hash(http: aiohttp.ClientSession, api_base: str):
|
||||
form = aiohttp.FormData(default_to_multipart=True)
|
||||
form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests"]))
|
||||
form.add_field("name", "x.safetensors")
|
||||
async with http.post(api_base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 400
|
||||
assert body["error"]["code"] == "MISSING_FILE"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_models_unknown_category(http: aiohttp.ClientSession, api_base: str):
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", b"A" * 128, filename="m.safetensors", content_type="application/octet-stream")
|
||||
form.add_field("tags", json.dumps(["models", "no_such_category", "unit-tests"]))
|
||||
form.add_field("name", "m.safetensors")
|
||||
async with http.post(api_base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
assert body["error"]["message"].startswith("unknown models category")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_models_requires_category(http: aiohttp.ClientSession, api_base: str):
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", b"A" * 64, filename="nocat.safetensors", content_type="application/octet-stream")
|
||||
form.add_field("tags", json.dumps(["models"])) # missing category
|
||||
form.add_field("name", "nocat.safetensors")
|
||||
form.add_field("user_metadata", json.dumps({}))
|
||||
async with http.post(api_base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 400
|
||||
assert body["error"]["code"] == "INVALID_BODY"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_tags_traversal_guard(http: aiohttp.ClientSession, api_base: str):
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", b"A" * 256, filename="evil.safetensors", content_type="application/octet-stream")
|
||||
# '..' should be rejected by the destination resolver
|
||||
form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]))
|
||||
form.add_field("name", "evil.safetensors")
|
||||
async with http.post(api_base + "/api/assets", data=form) as r:
|
||||
body = await r.json()
|
||||
assert r.status == 400
|
||||
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("root", ["input", "output"])
|
||||
async def test_duplicate_upload_same_display_name_does_not_clobber(
|
||||
root: str,
|
||||
http,
|
||||
api_base: str,
|
||||
asset_factory,
|
||||
make_asset_bytes,
|
||||
):
|
||||
"""
|
||||
Two uploads use the same tags and the same display name but different bytes.
|
||||
With hash-based filenames, they must NOT overwrite each other. Both assets
|
||||
remain accessible and serve their original content.
|
||||
"""
|
||||
scope = f"dup-path-{uuid.uuid4().hex[:6]}"
|
||||
display_name = "same_display.bin"
|
||||
|
||||
d1 = make_asset_bytes(scope + "-v1", 1536)
|
||||
d2 = make_asset_bytes(scope + "-v2", 2048)
|
||||
tags = [root, "unit-tests", scope]
|
||||
|
||||
first = await asset_factory(display_name, tags, {}, d1)
|
||||
second = await asset_factory(display_name, tags, {}, d2)
|
||||
|
||||
assert first["id"] != second["id"]
|
||||
assert first["asset_hash"] != second["asset_hash"] # different content
|
||||
assert first["name"] == second["name"] == display_name
|
||||
|
||||
# Both must be independently retrievable
|
||||
async with http.get(f"{api_base}/api/assets/{first['id']}/content") as r1:
|
||||
b1 = await r1.read()
|
||||
assert r1.status == 200
|
||||
assert b1 == d1
|
||||
async with http.get(f"{api_base}/api/assets/{second['id']}/content") as r2:
|
||||
b2 = await r2.read()
|
||||
assert r2.status == 200
|
||||
assert b2 == d2
|
||||
@ -1,62 +0,0 @@
|
||||
import pytest
|
||||
import base64
|
||||
import json
|
||||
import struct
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from aiohttp import web
|
||||
from unittest.mock import patch
|
||||
from app.model_manager import ModelFileManager
|
||||
|
||||
pytestmark = (
|
||||
pytest.mark.asyncio
|
||||
) # This applies the asyncio mark to all test functions in the module
|
||||
|
||||
@pytest.fixture
|
||||
def model_manager():
|
||||
return ModelFileManager()
|
||||
|
||||
@pytest.fixture
|
||||
def app(model_manager):
|
||||
app = web.Application()
|
||||
routes = web.RouteTableDef()
|
||||
model_manager.add_routes(routes)
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
||||
async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
|
||||
img = Image.new('RGB', (100, 100), 'white')
|
||||
img_byte_arr = BytesIO()
|
||||
img.save(img_byte_arr, format='PNG')
|
||||
img_byte_arr.seek(0)
|
||||
img_b64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
|
||||
|
||||
safetensors_file = tmp_path / "test_model.safetensors"
|
||||
header_bytes = json.dumps({
|
||||
"__metadata__": {
|
||||
"ssmd_cover_images": json.dumps([img_b64])
|
||||
}
|
||||
}).encode('utf-8')
|
||||
length_bytes = struct.pack('<Q', len(header_bytes))
|
||||
with open(safetensors_file, 'wb') as f:
|
||||
f.write(length_bytes)
|
||||
f.write(header_bytes)
|
||||
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/0/test_model.safetensors')
|
||||
|
||||
# Verify response
|
||||
assert response.status == 200
|
||||
assert response.content_type == 'image/webp'
|
||||
|
||||
# Verify the response contains valid image data
|
||||
img_bytes = BytesIO(await response.read())
|
||||
img = Image.open(img_bytes)
|
||||
assert img.format
|
||||
assert img.format.lower() == 'webp'
|
||||
|
||||
# Clean up
|
||||
img.close()
|
||||
@ -1,253 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.model_processor import ModelProcessor
|
||||
from app.database.models import Model, Base
|
||||
import os
|
||||
|
||||
# Test data constants
|
||||
TEST_MODEL_TYPE = "checkpoints"
|
||||
TEST_URL = "http://example.com/model.safetensors"
|
||||
TEST_FILE_NAME = "model.safetensors"
|
||||
TEST_EXPECTED_HASH = "abc123"
|
||||
TEST_DESTINATION_PATH = "/path/to/model.safetensors"
|
||||
|
||||
|
||||
def create_test_model(session, file_name, model_type, hash_value, file_size=1000, source_url=None):
|
||||
"""Helper to create a test model in the database."""
|
||||
model = Model(path=file_name, type=model_type, hash=hash_value, file_size=file_size, source_url=source_url)
|
||||
session.add(model)
|
||||
session.commit()
|
||||
return model
|
||||
|
||||
|
||||
def setup_mock_hash_calculation(model_processor, hash_value):
|
||||
"""Helper to setup hash calculation mocks."""
|
||||
mock_hash = MagicMock()
|
||||
mock_hash.hexdigest.return_value = hash_value
|
||||
return patch.object(model_processor, "_get_hasher", return_value=mock_hash)
|
||||
|
||||
|
||||
def verify_model_in_db(session, file_name, expected_hash=None, expected_type=None):
|
||||
"""Helper to verify model exists in database with correct attributes."""
|
||||
db_model = session.query(Model).filter_by(path=file_name).first()
|
||||
assert db_model is not None
|
||||
if expected_hash:
|
||||
assert db_model.hash == expected_hash
|
||||
if expected_type:
|
||||
assert db_model.type == expected_type
|
||||
return db_model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_engine():
|
||||
# Configure in-memory database
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
Base.metadata.create_all(engine)
|
||||
yield engine
|
||||
Base.metadata.drop_all(engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_session(db_engine):
|
||||
Session = sessionmaker(bind=db_engine)
|
||||
session = Session()
|
||||
yield session
|
||||
session.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_relative_path():
|
||||
with patch("app.model_processor.get_relative_path") as mock:
|
||||
mock.side_effect = lambda path: (TEST_MODEL_TYPE, os.path.basename(path))
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_full_path():
|
||||
with patch("app.model_processor.get_full_path") as mock:
|
||||
mock.return_value = TEST_DESTINATION_PATH
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
def model_processor(db_session, mock_get_relative_path, mock_get_full_path):
    with patch("app.model_processor.create_session", return_value=db_session):
        with patch("app.model_processor.can_create_session", return_value=True):
            processor = ModelProcessor()
            # Setup test state
            processor.removed_files = []
            processor.downloaded_files = []
            processor.file_exists = {}

            def mock_download_file(url, destination_path, hasher):
                processor.downloaded_files.append((url, destination_path))
                processor.file_exists[destination_path] = True
                # Simulate writing some data to the file
                test_data = b"test data"
                hasher.update(test_data)

            def mock_remove_file(file_path):
                processor.removed_files.append(file_path)
                if file_path in processor.file_exists:
                    del processor.file_exists[file_path]

            # Setup common patches
            file_exists_patch = patch.object(
                processor,
                "_file_exists",
                side_effect=lambda path: processor.file_exists.get(path, False),
            )
            file_size_patch = patch.object(
                processor,
                "_get_file_size",
                side_effect=lambda path: (
                    1000 if processor.file_exists.get(path, False) else 0
                ),
            )
            download_file_patch = patch.object(
                processor, "_download_file", side_effect=mock_download_file
            )
            remove_file_patch = patch.object(
                processor, "_remove_file", side_effect=mock_remove_file
            )

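            # Keep all patches active while the fixture yields, so the processor's
            # file and download operations stay in memory for each test.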
            with (
                file_exists_patch,
                file_size_patch,
                download_file_patch,
                remove_file_patch,
            ):
                yield processor


def test_ensure_downloaded_invalid_extension(model_processor):
    # Ensure that an unsupported file extension raises an error to prevent unsafe file downloads
    with pytest.raises(ValueError, match="Unsupported unsafe file for download"):
        model_processor.ensure_downloaded(TEST_MODEL_TYPE, TEST_URL, "model.exe")

def test_ensure_downloaded_existing_file_with_hash(model_processor, db_session):
    # Ensure that a file with the same hash but from a different source is not downloaded again
    SOURCE_URL = "https://example.com/other.sft"
    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH, source_url=SOURCE_URL)
    model_processor.file_exists[TEST_DESTINATION_PATH] = True

    result = model_processor.ensure_downloaded(
        TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
    )

    assert result == TEST_DESTINATION_PATH
    model = verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
    assert model.source_url == SOURCE_URL  # Ensure the source URL is not overwritten


def test_ensure_downloaded_existing_file_hash_mismatch(model_processor, db_session):
    # Ensure that a file with a different hash raises an error
    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, "different_hash")
    model_processor.file_exists[TEST_DESTINATION_PATH] = True

    with pytest.raises(ValueError, match="File .* exists with hash .* but expected .*"):
        model_processor.ensure_downloaded(
            TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
        )

def test_ensure_downloaded_new_file(model_processor, db_session):
    # Ensure that a new file is downloaded
    model_processor.file_exists[TEST_DESTINATION_PATH] = False

    with setup_mock_hash_calculation(model_processor, TEST_EXPECTED_HASH):
        result = model_processor.ensure_downloaded(
            TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
        )

    assert result == TEST_DESTINATION_PATH
    assert len(model_processor.downloaded_files) == 1
    assert model_processor.downloaded_files[0] == (TEST_URL, TEST_DESTINATION_PATH)
    assert model_processor.file_exists[TEST_DESTINATION_PATH]
    verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)

def test_ensure_downloaded_hash_mismatch(model_processor, db_session):
    # Ensure that a download that results in a different hash raises an error
    model_processor.file_exists[TEST_DESTINATION_PATH] = False

    with setup_mock_hash_calculation(model_processor, "different_hash"):
        with pytest.raises(
            ValueError,
            match="Downloaded file hash .* does not match expected hash .*",
        ):
            model_processor.ensure_downloaded(
                TEST_MODEL_TYPE,
                TEST_URL,
                TEST_FILE_NAME,
                TEST_EXPECTED_HASH,
            )

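    # The partially downloaded file must be removed and no Model row persisted.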
    assert len(model_processor.removed_files) == 1
    assert model_processor.removed_files[0] == TEST_DESTINATION_PATH
    assert TEST_DESTINATION_PATH not in model_processor.file_exists
    assert db_session.query(Model).filter_by(path=TEST_FILE_NAME).first() is None


def test_process_file_without_hash(model_processor, db_session):
    # Test processing file without provided hash
    model_processor.file_exists[TEST_DESTINATION_PATH] = True

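    # _hash_file is patched so no real file has to exist when the hash is computed.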
with patch.object(model_processor, "_hash_file", return_value=TEST_EXPECTED_HASH):
|
||||
result = model_processor.process_file(TEST_DESTINATION_PATH)
|
||||
assert result is not None
|
||||
assert result.hash == TEST_EXPECTED_HASH
|
||||
|
||||
|
||||
def test_retrieve_model_by_hash(model_processor, db_session):
|
||||
# Test retrieving model by hash
|
||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
||||
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH)
|
||||
assert result is not None
|
||||
assert result.hash == TEST_EXPECTED_HASH
|
||||
|
||||
|
||||
def test_retrieve_model_by_hash_and_type(model_processor, db_session):
    # Test retrieving model by hash and type
    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
    result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
    assert result is not None
    assert result.hash == TEST_EXPECTED_HASH
    assert result.type == TEST_MODEL_TYPE


def test_retrieve_hash(model_processor, db_session):
    # Test retrieving hash for existing model
    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
    with patch.object(
        model_processor,
        "_validate_path",
        return_value=(TEST_MODEL_TYPE, TEST_FILE_NAME),
    ):
        result = model_processor.retrieve_hash(TEST_DESTINATION_PATH, TEST_MODEL_TYPE)
        assert result == TEST_EXPECTED_HASH

def test_validate_file_extension_valid_extensions(model_processor):
    # Test all valid file extensions
    valid_extensions = [".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"]
    for ext in valid_extensions:
        model_processor._validate_file_extension(f"test{ext}")  # Should not raise


def test_process_file_existing_without_source_url(model_processor, db_session):
    # Test processing an existing file that needs its source URL updated
    model_processor.file_exists[TEST_DESTINATION_PATH] = True

    create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
    result = model_processor.process_file(TEST_DESTINATION_PATH, source_url=TEST_URL)

    assert result is not None
    assert result.hash == TEST_EXPECTED_HASH
    assert result.source_url == TEST_URL

    db_model = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first()
    assert db_model.source_url == TEST_URL