Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-16 01:37:04 +08:00)

Compare commits: 11 commits, f5bc6b3fd0...9f53409c3b

Commits (SHA1):
- 9f53409c3b
- a6fed841b5
- 6592bffc60
- 604b00c1f6
- 3819bf238f
- 20c57cbc6a
- c72236dd32
- 098b3dd5d7
- b3ec1b3f05
- 9fedc506b5
- 51f553386d
alembic_db/versions/0001_assets.py (new file, 174 additions)
@@ -0,0 +1,174 @@

"""
Initial assets schema
Revision ID: 0001_assets
Revises: None
Create Date: 2025-12-10 00:00:00
"""

from alembic import op
import sqlalchemy as sa

revision = "0001_assets"
down_revision = None
branch_labels = None
depends_on = None


def upgrade() -> None:
    # ASSETS: content identity
    op.create_table(
        "assets",
        sa.Column("id", sa.String(length=36), primary_key=True),
        sa.Column("hash", sa.String(length=256), nullable=True),
        sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
        sa.Column("mime_type", sa.String(length=255), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
        sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
    )
    op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
    op.create_index("ix_assets_mime_type", "assets", ["mime_type"])

    # ASSETS_INFO: user-visible references
    op.create_table(
        "assets_info",
        sa.Column("id", sa.String(length=36), primary_key=True),
        sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
        sa.Column("name", sa.String(length=512), nullable=False),
        sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
        sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
        sa.Column("user_metadata", sa.JSON(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
        sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
        sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
    )
    op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
    op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
    op.create_index("ix_assets_info_name", "assets_info", ["name"])
    op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
    op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
    op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])

    # TAGS: normalized tag vocabulary
    op.create_table(
        "tags",
        sa.Column("name", sa.String(length=512), primary_key=True),
        sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
        sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
    )
    op.create_index("ix_tags_tag_type", "tags", ["tag_type"])

    # ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
    op.create_table(
        "asset_info_tags",
        sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
        sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
        sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
        sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
        sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
    )
    op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
    op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])

    # ASSET_CACHE_STATE: N:1 local cache rows per Asset
    op.create_table(
        "asset_cache_state",
        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
        sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
        sa.Column("file_path", sa.Text(), nullable=False),  # absolute local path to cached file
        sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
        sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
        sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
        sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
    )
    op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
    op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])

    # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
    op.create_table(
        "asset_info_meta",
        sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
        sa.Column("key", sa.String(length=256), nullable=False),
        sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("val_str", sa.String(length=2048), nullable=True),
        sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
        sa.Column("val_bool", sa.Boolean(), nullable=True),
        sa.Column("val_json", sa.JSON(), nullable=True),
        sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
    )
    op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
    op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
    op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
    op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])

    # Tags vocabulary
    tags_table = sa.table(
        "tags",
        sa.column("name", sa.String(length=512)),
        sa.column("tag_type", sa.String()),
    )
    op.bulk_insert(
        tags_table,
        [
            {"name": "models", "tag_type": "system"},
            {"name": "input", "tag_type": "system"},
            {"name": "output", "tag_type": "system"},

            {"name": "configs", "tag_type": "system"},
            {"name": "checkpoints", "tag_type": "system"},
            {"name": "loras", "tag_type": "system"},
            {"name": "vae", "tag_type": "system"},
            {"name": "text_encoders", "tag_type": "system"},
            {"name": "diffusion_models", "tag_type": "system"},
            {"name": "clip_vision", "tag_type": "system"},
            {"name": "style_models", "tag_type": "system"},
            {"name": "embeddings", "tag_type": "system"},
            {"name": "diffusers", "tag_type": "system"},
            {"name": "vae_approx", "tag_type": "system"},
            {"name": "controlnet", "tag_type": "system"},
            {"name": "gligen", "tag_type": "system"},
            {"name": "upscale_models", "tag_type": "system"},
            {"name": "hypernetworks", "tag_type": "system"},
            {"name": "photomaker", "tag_type": "system"},
            {"name": "classifiers", "tag_type": "system"},

            {"name": "encoder", "tag_type": "system"},
            {"name": "decoder", "tag_type": "system"},

            {"name": "missing", "tag_type": "system"},
            {"name": "rescan", "tag_type": "system"},
        ],
    )


def downgrade() -> None:
    op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
    op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
    op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
    op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
    op.drop_table("asset_info_meta")

    op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
    op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
    op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
    op.drop_table("asset_cache_state")

    op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
    op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
    op.drop_table("asset_info_tags")

    op.drop_index("ix_tags_tag_type", table_name="tags")
    op.drop_table("tags")

    op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
    op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
    op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
    op.drop_index("ix_assets_info_created_at", table_name="assets_info")
    op.drop_index("ix_assets_info_name", table_name="assets_info")
    op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
    op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
    op.drop_table("assets_info")

    op.drop_index("uq_assets_hash", table_name="assets")
    op.drop_index("ix_assets_mime_type", table_name="assets")
    op.drop_table("assets")
app/assets/api/routes.py (new file, 102 additions)
@@ -0,0 +1,102 @@

import logging
import uuid
from aiohttp import web

from pydantic import ValidationError

import app.assets.manager as manager
from app import user_manager
from app.assets.api import schemas_in
from app.assets.helpers import get_query_dict

ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None

# UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"


def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
    global USER_MANAGER
    USER_MANAGER = user_manager_instance
    app.add_routes(ROUTES)


def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
    return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)


def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
    return _error_response(400, code, "Validation failed.", {"errors": ve.json()})


@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
    """
    GET request to list assets.
    """
    query_dict = get_query_dict(request)
    try:
        q = schemas_in.ListAssetsQuery.model_validate(query_dict)
    except ValidationError as ve:
        return _validation_error_response("INVALID_QUERY", ve)

    payload = manager.list_assets(
        include_tags=q.include_tags,
        exclude_tags=q.exclude_tags,
        name_contains=q.name_contains,
        metadata_filter=q.metadata_filter,
        limit=q.limit,
        offset=q.offset,
        sort=q.sort,
        order=q.order,
        owner_id=USER_MANAGER.get_request_user_id(request),
    )
    return web.json_response(payload.model_dump(mode="json"))


@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
async def get_asset(request: web.Request) -> web.Response:
    """
    GET request to get an asset's info as JSON.
    """
    asset_info_id = str(uuid.UUID(request.match_info["id"]))
    try:
        result = manager.get_asset(
            asset_info_id=asset_info_id,
            owner_id=USER_MANAGER.get_request_user_id(request),
        )
    except ValueError as e:
        return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id})
    except Exception:
        logging.exception(
            "get_asset failed for asset_info_id=%s, owner_id=%s",
            asset_info_id,
            USER_MANAGER.get_request_user_id(request),
        )
        return _error_response(500, "INTERNAL", "Unexpected server error.")
    return web.json_response(result.model_dump(mode="json"), status=200)


@ROUTES.get("/api/tags")
async def get_tags(request: web.Request) -> web.Response:
    """
    GET request to list all tags based on query parameters.
    """
    query_map = dict(request.rel_url.query)

    try:
        query = schemas_in.TagsListQuery.model_validate(query_map)
    except ValidationError as e:
        return web.json_response(
            {"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}},
            status=400,
        )

    result = manager.list_tags(
        prefix=query.prefix,
        limit=query.limit,
        offset=query.offset,
        order=query.order,
        include_zero=query.include_zero,
        owner_id=USER_MANAGER.get_request_user_id(request),
    )
    return web.json_response(result.model_dump(mode="json"))
app/assets/api/schemas_in.py (new file, 94 additions)
@@ -0,0 +1,94 @@

import json
import uuid
from typing import Any, Literal

from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    conint,
    field_validator,
)


class ListAssetsQuery(BaseModel):
    include_tags: list[str] = Field(default_factory=list)
    exclude_tags: list[str] = Field(default_factory=list)
    name_contains: str | None = None

    # Accept either a JSON string (query param) or a dict
    metadata_filter: dict[str, Any] | None = None

    limit: conint(ge=1, le=500) = 20
    offset: conint(ge=0) = 0

    sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
    order: Literal["asc", "desc"] = "desc"

    @field_validator("include_tags", "exclude_tags", mode="before")
    @classmethod
    def _split_csv_tags(cls, v):
        # Accept "a,b,c" or ["a","b"] (we are liberal in what we accept)
        if v is None:
            return []
        if isinstance(v, str):
            return [t.strip() for t in v.split(",") if t.strip()]
        if isinstance(v, list):
            out: list[str] = []
            for item in v:
                if isinstance(item, str):
                    out.extend([t.strip() for t in item.split(",") if t.strip()])
            return out
        return v

    @field_validator("metadata_filter", mode="before")
    @classmethod
    def _parse_metadata_json(cls, v):
        if v is None or isinstance(v, dict):
            return v
        if isinstance(v, str) and v.strip():
            try:
                parsed = json.loads(v)
            except Exception as e:
                raise ValueError(f"metadata_filter must be JSON: {e}") from e
            if not isinstance(parsed, dict):
                raise ValueError("metadata_filter must be a JSON object")
            return parsed
        return None


class TagsListQuery(BaseModel):
    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)

    prefix: str | None = Field(None, min_length=1, max_length=256)
    limit: int = Field(100, ge=1, le=1000)
    offset: int = Field(0, ge=0, le=10_000_000)
    order: Literal["count_desc", "name_asc"] = "count_desc"
    include_zero: bool = True

    @field_validator("prefix")
    @classmethod
    def normalize_prefix(cls, v: str | None) -> str | None:
        if v is None:
            return v
        v = v.strip()
        return v.lower() or None


class SetPreviewBody(BaseModel):
    """Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
    preview_id: str | None = None

    @field_validator("preview_id", mode="before")
    @classmethod
    def _norm_uuid(cls, v):
        if v is None:
            return None
        s = str(v).strip()
        if not s:
            return None
        try:
            uuid.UUID(s)
        except Exception:
            raise ValueError("preview_id must be a UUID")
        return s
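A quick illustration of how these query validators normalize incoming parameters (a minimal sketch, not part of the diff; it assumes the module is importable as app.assets.api.schemas_in, exactly as added above):

# Illustrative only: exercises ListAssetsQuery's "before" validators.
from app.assets.api.schemas_in import ListAssetsQuery

q = ListAssetsQuery.model_validate({
    "include_tags": "models,checkpoints",          # CSV string is split into a list
    "exclude_tags": ["missing", "rescan,vae"],     # list items may themselves be CSV
    "metadata_filter": '{"filename": "flux/123/flux.safetensors"}',  # JSON string -> dict
    "limit": "50",                                 # pydantic coerces numeric strings
})
assert q.include_tags == ["models", "checkpoints"]
assert q.exclude_tags == ["missing", "rescan", "vae"]
assert q.metadata_filter == {"filename": "flux/123/flux.safetensors"}
assert q.limit == 50 and q.sort == "created_at" and q.order == "desc"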
app/assets/api/schemas_out.py (new file, 60 additions)
@@ -0,0 +1,60 @@

from datetime import datetime
from typing import Any

from pydantic import BaseModel, ConfigDict, Field, field_serializer


class AssetSummary(BaseModel):
    id: str
    name: str
    asset_hash: str | None = None
    size: int | None = None
    mime_type: str | None = None
    tags: list[str] = Field(default_factory=list)
    preview_url: str | None = None
    created_at: datetime | None = None
    updated_at: datetime | None = None
    last_access_time: datetime | None = None

    model_config = ConfigDict(from_attributes=True)

    @field_serializer("created_at", "updated_at", "last_access_time")
    def _ser_dt(self, v: datetime | None, _info):
        return v.isoformat() if v else None


class AssetsList(BaseModel):
    assets: list[AssetSummary]
    total: int
    has_more: bool


class AssetDetail(BaseModel):
    id: str
    name: str
    asset_hash: str | None = None
    size: int | None = None
    mime_type: str | None = None
    tags: list[str] = Field(default_factory=list)
    user_metadata: dict[str, Any] = Field(default_factory=dict)
    preview_id: str | None = None
    created_at: datetime | None = None
    last_access_time: datetime | None = None

    model_config = ConfigDict(from_attributes=True)

    @field_serializer("created_at", "last_access_time")
    def _ser_dt(self, v: datetime | None, _info):
        return v.isoformat() if v else None


class TagUsage(BaseModel):
    name: str
    count: int
    type: str


class TagsList(BaseModel):
    tags: list[TagUsage] = Field(default_factory=list)
    total: int
    has_more: bool
app/assets/database/bulk_ops.py (new file, 188 additions)
@@ -0,0 +1,188 @@

import os
import uuid
import sqlalchemy
from typing import Iterable
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite

from app.assets.helpers import utcnow
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta

MAX_BIND_PARAMS = 800


def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
    if not rows:
        return []
    rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
    for i in range(0, len(rows), rows_per_stmt):
        yield rows[i:i + rows_per_stmt]


def _iter_chunks(seq, n: int):
    for i in range(0, len(seq), n):
        yield seq[i:i + n]


def _rows_per_stmt(cols: int) -> int:
    return max(1, MAX_BIND_PARAMS // max(1, cols))


def seed_from_paths_batch(
    session: Session,
    *,
    specs: list[dict],
    owner_id: str = "",
) -> dict:
    """Each spec is a dict with keys:
    - abs_path: str
    - size_bytes: int
    - mtime_ns: int
    - info_name: str
    - tags: list[str]
    - fname: Optional[str]
    """
    if not specs:
        return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}

    now = utcnow()
    asset_rows: list[dict] = []
    state_rows: list[dict] = []
    path_to_asset: dict[str, str] = {}
    asset_to_info: dict[str, dict] = {}  # asset_id -> prepared info row
    path_list: list[str] = []

    for sp in specs:
        ap = os.path.abspath(sp["abs_path"])
        aid = str(uuid.uuid4())
        iid = str(uuid.uuid4())
        path_list.append(ap)
        path_to_asset[ap] = aid

        asset_rows.append(
            {
                "id": aid,
                "hash": None,
                "size_bytes": sp["size_bytes"],
                "mime_type": None,
                "created_at": now,
            }
        )
        state_rows.append(
            {
                "asset_id": aid,
                "file_path": ap,
                "mtime_ns": sp["mtime_ns"],
            }
        )
        asset_to_info[aid] = {
            "id": iid,
            "owner_id": owner_id,
            "name": sp["info_name"],
            "asset_id": aid,
            "preview_id": None,
            "user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
            "created_at": now,
            "updated_at": now,
            "last_access_time": now,
            "_tags": sp["tags"],
            "_filename": sp["fname"],
        }

    # insert all seed Assets (hash=NULL)
    ins_asset = sqlite.insert(Asset)
    for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
        session.execute(ins_asset, chunk)

    # try to claim AssetCacheState (file_path)
    winners_by_path: set[str] = set()
    ins_state = (
        sqlite.insert(AssetCacheState)
        .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
        .returning(AssetCacheState.file_path)
    )
    for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
        winners_by_path.update((session.execute(ins_state, chunk)).scalars().all())

    all_paths_set = set(path_list)
    losers_by_path = all_paths_set - winners_by_path
    lost_assets = [path_to_asset[p] for p in losers_by_path]
    if lost_assets:  # losers get their Asset removed
        for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
            session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))

    if not winners_by_path:
        return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}

    # insert AssetInfo only for winners
    winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
    ins_info = (
        sqlite.insert(AssetInfo)
        .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
        .returning(AssetInfo.id)
    )

    inserted_info_ids: set[str] = set()
    for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
        inserted_info_ids.update((session.execute(ins_info, chunk)).scalars().all())

    # build and insert tag + meta rows for the AssetInfo
    tag_rows: list[dict] = []
    meta_rows: list[dict] = []
    if inserted_info_ids:
        for row in winner_info_rows:
            iid = row["id"]
            if iid not in inserted_info_ids:
                continue
            for t in row["_tags"]:
                tag_rows.append({
                    "asset_info_id": iid,
                    "tag_name": t,
                    "origin": "automatic",
                    "added_at": now,
                })
            if row["_filename"]:
                meta_rows.append(
                    {
                        "asset_info_id": iid,
                        "key": "filename",
                        "ordinal": 0,
                        "val_str": row["_filename"],
                        "val_num": None,
                        "val_bool": None,
                        "val_json": None,
                    }
                )

    bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
    return {
        "inserted_infos": len(inserted_info_ids),
        "won_states": len(winners_by_path),
        "lost_states": len(losers_by_path),
    }


def bulk_insert_tags_and_meta(
    session: Session,
    *,
    tag_rows: list[dict],
    meta_rows: list[dict],
    max_bind_params: int,
) -> None:
    """Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
    - tag_rows keys: asset_info_id, tag_name, origin, added_at
    - meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
    """
    if tag_rows:
        ins_links = (
            sqlite.insert(AssetInfoTag)
            .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
        )
        for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
            session.execute(ins_links, chunk)
    if meta_rows:
        ins_meta = (
            sqlite.insert(AssetInfoMeta)
            .on_conflict_do_nothing(
                index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
            )
        )
        for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
            session.execute(ins_meta, chunk)
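The chunk sizes used above are driven by a per-statement bind-parameter budget; a rough sketch of that arithmetic, using the MAX_BIND_PARAMS value and the column counts from this file:

# Rows per multi-row INSERT, given roughly 800 bind parameters per statement.
MAX_BIND_PARAMS = 800
cols_assets = 5   # id, hash, size_bytes, mime_type, created_at
cols_states = 3   # asset_id, file_path, mtime_ns
cols_infos = 9    # id, owner_id, name, asset_id, preview_id, user_metadata, created_at, updated_at, last_access_time

print(MAX_BIND_PARAMS // cols_assets)  # 160 asset rows per INSERT
print(MAX_BIND_PARAMS // cols_states)  # 266 cache-state rows per INSERT
print(MAX_BIND_PARAMS // cols_infos)   # 88 asset-info rows per INSERT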
app/assets/database/models.py (new file, 233 additions)
@@ -0,0 +1,233 @@

from __future__ import annotations

import uuid
from datetime import datetime

from typing import Any
from sqlalchemy import (
    JSON,
    BigInteger,
    Boolean,
    CheckConstraint,
    DateTime,
    ForeignKey,
    Index,
    Integer,
    Numeric,
    String,
    Text,
    UniqueConstraint,
)
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship

from app.assets.helpers import utcnow
from app.database.models import to_dict, Base


class Asset(Base):
    __tablename__ = "assets"

    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
    size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
    mime_type: Mapped[str | None] = mapped_column(String(255))
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=False), nullable=False, default=utcnow
    )

    infos: Mapped[list[AssetInfo]] = relationship(
        "AssetInfo",
        back_populates="asset",
        primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
        foreign_keys=lambda: [AssetInfo.asset_id],
        cascade="all,delete-orphan",
        passive_deletes=True,
    )

    preview_of: Mapped[list[AssetInfo]] = relationship(
        "AssetInfo",
        back_populates="preview_asset",
        primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
        foreign_keys=lambda: [AssetInfo.preview_id],
        viewonly=True,
    )

    cache_states: Mapped[list[AssetCacheState]] = relationship(
        back_populates="asset",
        cascade="all, delete-orphan",
        passive_deletes=True,
    )

    __table_args__ = (
        Index("uq_assets_hash", "hash", unique=True),
        Index("ix_assets_mime_type", "mime_type"),
        CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
    )

    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
        return to_dict(self, include_none=include_none)

    def __repr__(self) -> str:
        return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"


class AssetCacheState(Base):
    __tablename__ = "asset_cache_state"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
    file_path: Mapped[str] = mapped_column(Text, nullable=False)
    mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
    needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    asset: Mapped[Asset] = relationship(back_populates="cache_states")

    __table_args__ = (
        Index("ix_asset_cache_state_file_path", "file_path"),
        Index("ix_asset_cache_state_asset_id", "asset_id"),
        CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
        UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
    )

    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
        return to_dict(self, include_none=include_none)

    def __repr__(self) -> str:
        return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"


class AssetInfo(Base):
    __tablename__ = "assets_info"

    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
    name: Mapped[str] = mapped_column(String(512), nullable=False)
    asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
    preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
    user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
    last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)

    asset: Mapped[Asset] = relationship(
        "Asset",
        back_populates="infos",
        foreign_keys=[asset_id],
        lazy="selectin",
    )
    preview_asset: Mapped[Asset | None] = relationship(
        "Asset",
        back_populates="preview_of",
        foreign_keys=[preview_id],
    )

    metadata_entries: Mapped[list[AssetInfoMeta]] = relationship(
        back_populates="asset_info",
        cascade="all,delete-orphan",
        passive_deletes=True,
    )

    tag_links: Mapped[list[AssetInfoTag]] = relationship(
        back_populates="asset_info",
        cascade="all,delete-orphan",
        passive_deletes=True,
        overlaps="tags,asset_infos",
    )

    tags: Mapped[list[Tag]] = relationship(
        secondary="asset_info_tags",
        back_populates="asset_infos",
        lazy="selectin",
        viewonly=True,
        overlaps="tag_links,asset_info_links,asset_infos,tag",
    )

    __table_args__ = (
        UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
        Index("ix_assets_info_owner_name", "owner_id", "name"),
        Index("ix_assets_info_owner_id", "owner_id"),
        Index("ix_assets_info_asset_id", "asset_id"),
        Index("ix_assets_info_name", "name"),
        Index("ix_assets_info_created_at", "created_at"),
        Index("ix_assets_info_last_access_time", "last_access_time"),
    )

    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
        data = to_dict(self, include_none=include_none)
        data["tags"] = [t.name for t in self.tags]
        return data

    def __repr__(self) -> str:
        return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"


class AssetInfoMeta(Base):
    __tablename__ = "asset_info_meta"

    asset_info_id: Mapped[str] = mapped_column(
        String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
    )
    key: Mapped[str] = mapped_column(String(256), primary_key=True)
    ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)

    val_str: Mapped[str | None] = mapped_column(String(2048), nullable=True)
    val_num: Mapped[float | None] = mapped_column(Numeric(38, 10), nullable=True)
    val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
    val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)

    asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries")

    __table_args__ = (
        Index("ix_asset_info_meta_key", "key"),
        Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
        Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
        Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
    )


class AssetInfoTag(Base):
    __tablename__ = "asset_info_tags"

    asset_info_id: Mapped[str] = mapped_column(
        String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
    )
    tag_name: Mapped[str] = mapped_column(
        String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
    )
    origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
    added_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=False), nullable=False, default=utcnow
    )

    asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
    tag: Mapped[Tag] = relationship(back_populates="asset_info_links")

    __table_args__ = (
        Index("ix_asset_info_tags_tag_name", "tag_name"),
        Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
    )


class Tag(Base):
    __tablename__ = "tags"

    name: Mapped[str] = mapped_column(String(512), primary_key=True)
    tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")

    asset_info_links: Mapped[list[AssetInfoTag]] = relationship(
        back_populates="tag",
        overlaps="asset_infos,tags",
    )
    asset_infos: Mapped[list[AssetInfo]] = relationship(
        secondary="asset_info_tags",
        back_populates="tags",
        viewonly=True,
        overlaps="asset_info_links,tag_links,tags,asset_info",
    )

    __table_args__ = (
        Index("ix_tags_tag_type", "tag_type"),
    )

    def __repr__(self) -> str:
        return f"<Tag {self.name}>"
app/assets/database/queries.py (new file, 267 additions)
@@ -0,0 +1,267 @@

import sqlalchemy as sa
from collections import defaultdict
from sqlalchemy import select, exists, func
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import escape_like_prefix, normalize_tags
from typing import Sequence


def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
    """Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
    owner_id = (owner_id or "").strip()
    if owner_id == "":
        return AssetInfo.owner_id == ""
    return AssetInfo.owner_id.in_(["", owner_id])


def apply_tag_filters(
    stmt: sa.sql.Select,
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
    """include_tags: every tag must be present; exclude_tags: none may be present."""
    include_tags = normalize_tags(include_tags)
    exclude_tags = normalize_tags(exclude_tags)

    if include_tags:
        for tag_name in include_tags:
            stmt = stmt.where(
                exists().where(
                    (AssetInfoTag.asset_info_id == AssetInfo.id)
                    & (AssetInfoTag.tag_name == tag_name)
                )
            )

    if exclude_tags:
        stmt = stmt.where(
            ~exists().where(
                (AssetInfoTag.asset_info_id == AssetInfo.id)
                & (AssetInfoTag.tag_name.in_(exclude_tags))
            )
        )
    return stmt


def apply_metadata_filter(
    stmt: sa.sql.Select,
    metadata_filter: dict | None = None,
) -> sa.sql.Select:
    """Apply filters using asset_info_meta projection table."""
    if not metadata_filter:
        return stmt

    def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
        return sa.exists().where(
            AssetInfoMeta.asset_info_id == AssetInfo.id,
            AssetInfoMeta.key == key,
            *preds,
        )

    def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
        if value is None:
            no_row_for_key = sa.not_(
                sa.exists().where(
                    AssetInfoMeta.asset_info_id == AssetInfo.id,
                    AssetInfoMeta.key == key,
                )
            )
            null_row = _exists_for_pred(
                key,
                AssetInfoMeta.val_json.is_(None),
                AssetInfoMeta.val_str.is_(None),
                AssetInfoMeta.val_num.is_(None),
                AssetInfoMeta.val_bool.is_(None),
            )
            return sa.or_(no_row_for_key, null_row)

        if isinstance(value, bool):
            return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
        if isinstance(value, (int, float)):
            from decimal import Decimal
            num = value if isinstance(value, Decimal) else Decimal(str(value))
            return _exists_for_pred(key, AssetInfoMeta.val_num == num)
        if isinstance(value, str):
            return _exists_for_pred(key, AssetInfoMeta.val_str == value)
        return _exists_for_pred(key, AssetInfoMeta.val_json == value)

    for k, v in metadata_filter.items():
        if isinstance(v, list):
            ors = [_exists_clause_for_value(k, elem) for elem in v]
            if ors:
                stmt = stmt.where(sa.or_(*ors))
        else:
            stmt = stmt.where(_exists_clause_for_value(k, v))
    return stmt


def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
    """
    Check if an asset with a given hash exists in database.
    """
    row = (
        session.execute(
            select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
        )
    ).first()
    return row is not None


def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
    return session.get(AssetInfo, asset_info_id)


def list_asset_infos_page(
    session: Session,
    owner_id: str = "",
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
    name_contains: str | None = None,
    metadata_filter: dict | None = None,
    limit: int = 20,
    offset: int = 0,
    sort: str = "created_at",
    order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
    base = (
        select(AssetInfo)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
        .where(visible_owner_clause(owner_id))
    )

    if name_contains:
        escaped, esc = escape_like_prefix(name_contains)
        base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))

    base = apply_tag_filters(base, include_tags, exclude_tags)
    base = apply_metadata_filter(base, metadata_filter)

    sort = (sort or "created_at").lower()
    order = (order or "desc").lower()
    sort_map = {
        "name": AssetInfo.name,
        "created_at": AssetInfo.created_at,
        "updated_at": AssetInfo.updated_at,
        "last_access_time": AssetInfo.last_access_time,
        "size": Asset.size_bytes,
    }
    sort_col = sort_map.get(sort, AssetInfo.created_at)
    sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()

    base = base.order_by(sort_exp).limit(limit).offset(offset)

    count_stmt = (
        select(sa.func.count())
        .select_from(AssetInfo)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .where(visible_owner_clause(owner_id))
    )
    if name_contains:
        escaped, esc = escape_like_prefix(name_contains)
        count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
    count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
    count_stmt = apply_metadata_filter(count_stmt, metadata_filter)

    total = int((session.execute(count_stmt)).scalar_one() or 0)

    infos = (session.execute(base)).unique().scalars().all()

    id_list: list[str] = [i.id for i in infos]
    tag_map: dict[str, list[str]] = defaultdict(list)
    if id_list:
        rows = session.execute(
            select(AssetInfoTag.asset_info_id, Tag.name)
            .join(Tag, Tag.name == AssetInfoTag.tag_name)
            .where(AssetInfoTag.asset_info_id.in_(id_list))
        )
        for aid, tag_name in rows.all():
            tag_map[aid].append(tag_name)

    return infos, tag_map, total


def fetch_asset_info_asset_and_tags(
    session: Session,
    asset_info_id: str,
    owner_id: str = "",
) -> tuple[AssetInfo, Asset, list[str]] | None:
    stmt = (
        select(AssetInfo, Asset, Tag.name)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
        .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
        .where(
            AssetInfo.id == asset_info_id,
            visible_owner_clause(owner_id),
        )
        .options(noload(AssetInfo.tags))
        .order_by(Tag.name.asc())
    )

    rows = (session.execute(stmt)).all()
    if not rows:
        return None

    first_info, first_asset, _ = rows[0]
    tags: list[str] = []
    seen: set[str] = set()
    for _info, _asset, tag_name in rows:
        if tag_name and tag_name not in seen:
            seen.add(tag_name)
            tags.append(tag_name)
    return first_info, first_asset, tags


def list_tags_with_usage(
    session: Session,
    prefix: str | None = None,
    limit: int = 100,
    offset: int = 0,
    include_zero: bool = True,
    order: str = "count_desc",
    owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
    counts_sq = (
        select(
            AssetInfoTag.tag_name.label("tag_name"),
            func.count(AssetInfoTag.asset_info_id).label("cnt"),
        )
        .select_from(AssetInfoTag)
        .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
        .where(visible_owner_clause(owner_id))
        .group_by(AssetInfoTag.tag_name)
        .subquery()
    )

    q = (
        select(
            Tag.name,
            Tag.tag_type,
            func.coalesce(counts_sq.c.cnt, 0).label("count"),
        )
        .select_from(Tag)
        .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
    )

    if prefix:
        escaped, esc = escape_like_prefix(prefix.strip().lower())
        q = q.where(Tag.name.like(escaped + "%", escape=esc))

    if not include_zero:
        q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)

    if order == "name_asc":
        q = q.order_by(Tag.name.asc())
    else:
        q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())

    total_q = select(func.count()).select_from(Tag)
    if prefix:
        escaped, esc = escape_like_prefix(prefix.strip().lower())
        total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
    if not include_zero:
        total_q = total_q.where(
            Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
        )

    rows = (session.execute(q.limit(limit).offset(offset))).all()
    total = (session.execute(total_q)).scalar_one()

    rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
    return rows_norm, int(total or 0)
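A usage sketch for the query builders above (illustrative only, not part of the diff; it assumes the models and queries modules are importable as added in this PR and simply prints the SQL that the filters generate):

# Illustrative only: compose a list statement and compile it for the SQLite dialect.
import sqlalchemy as sa
from sqlalchemy.dialects import sqlite

from app.assets.database.models import Asset, AssetInfo
from app.assets.database.queries import apply_tag_filters, visible_owner_clause

stmt = (
    sa.select(AssetInfo)
    .join(Asset, Asset.id == AssetInfo.asset_id)
    .where(visible_owner_clause("alice"))          # rows owned by "" or "alice"
)
stmt = apply_tag_filters(stmt, include_tags=["models"], exclude_tags=["missing"])
print(stmt.compile(dialect=sqlite.dialect(), compile_kwargs={"literal_binds": True}))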
app/assets/database/tags.py (new file, 62 additions)
@@ -0,0 +1,62 @@

from typing import Iterable

import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite

from app.assets.helpers import normalize_tags, utcnow
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo


def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
    wanted = normalize_tags(list(names))
    if not wanted:
        return
    rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
    ins = (
        sqlite.insert(Tag)
        .values(rows)
        .on_conflict_do_nothing(index_elements=[Tag.name])
    )
    return session.execute(ins)


def add_missing_tag_for_asset_id(
    session: Session,
    *,
    asset_id: str,
    origin: str = "automatic",
) -> None:
    select_rows = (
        sqlalchemy.select(
            AssetInfo.id.label("asset_info_id"),
            sqlalchemy.literal("missing").label("tag_name"),
            sqlalchemy.literal(origin).label("origin"),
            sqlalchemy.literal(utcnow()).label("added_at"),
        )
        .where(AssetInfo.asset_id == asset_id)
        .where(
            sqlalchemy.not_(
                sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
            )
        )
    )
    session.execute(
        sqlite.insert(AssetInfoTag)
        .from_select(
            ["asset_info_id", "tag_name", "origin", "added_at"],
            select_rows,
        )
        .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
    )


def remove_missing_tag_for_asset_id(
    session: Session,
    *,
    asset_id: str,
) -> None:
    session.execute(
        sqlalchemy.delete(AssetInfoTag).where(
            AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
            AssetInfoTag.tag_name == "missing",
        )
    )
app/assets/hashing.py (new file, 76 additions)
@@ -0,0 +1,76 @@

from blake3 import blake3
from typing import IO, Union
import os
import asyncio


DEFAULT_CHUNK = 8 * 1024 * 1024  # 8 MB

# NOTE: this allows hashing different representations of a file-like object
def blake3_hash(
    fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
    chunk_size: int = DEFAULT_CHUNK,
) -> str:
    """
    Returns a BLAKE3 hex digest for ``fp``, which may be:
    - a filename (str/bytes) or PathLike
    - an open binary file object
    If ``fp`` is a file object, it must be opened in **binary** mode and support
    ``read``, ``seek``, and ``tell``. The function will seek to the start before
    reading and will attempt to restore the original position afterward.
    """
    if hasattr(fp, "read"):
        return _hash_file_obj(fp, chunk_size)

    with open(os.fspath(fp), "rb") as f:
        return _hash_file_obj(f, chunk_size)


async def blake3_hash_async(
    fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
    chunk_size: int = DEFAULT_CHUNK,
) -> str:
    """Async wrapper for ``blake3_hash``.
    Uses a worker thread so the event loop remains responsive.
    """
    # If it is a path, open inside the worker thread to keep I/O off the loop.
    if hasattr(fp, "read"):
        return await asyncio.to_thread(blake3_hash, fp, chunk_size)

    def _worker() -> str:
        with open(os.fspath(fp), "rb") as f:
            return _hash_file_obj(f, chunk_size)

    return await asyncio.to_thread(_worker)


def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
    """
    Hash an already-open binary file object by streaming in chunks.
    - Seeks to the beginning before reading (if supported).
    - Restores the original position afterward (if tell/seek are supported).
    """
    if chunk_size <= 0:
        chunk_size = DEFAULT_CHUNK

    # in case the file object is already open and not at the beginning, track the position so it can be restored after hashing
    orig_pos = None
    if hasattr(file_obj, "tell"):
        orig_pos = file_obj.tell()

    try:
        if hasattr(file_obj, "seek"):
            # seek to the beginning before reading
            file_obj.seek(0)

        h = blake3()
        while True:
            chunk = file_obj.read(chunk_size)
            if not chunk:
                break
            h.update(chunk)
        return h.hexdigest()
    finally:
        # restore original position in file object, if needed
        if hasattr(file_obj, "seek") and orig_pos is not None:
            file_obj.seek(orig_pos)
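A small usage sketch for the hashing helpers (illustrative only, not part of the diff; it assumes the blake3 package is installed and the module is importable as app.assets.hashing):

# Illustrative only: hash an in-memory binary stream and confirm the position is restored.
import asyncio
import io

from app.assets.hashing import blake3_hash, blake3_hash_async

buf = io.BytesIO(b"example bytes")
buf.seek(7)                   # start somewhere other than 0
digest = blake3_hash(buf)     # streams from the beginning in 8 MB chunks
assert buf.tell() == 7        # original position restored

# Same digest via the async wrapper (reads happen in a worker thread).
assert asyncio.run(blake3_hash_async(io.BytesIO(b"example bytes"))) == digest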
216
app/assets/helpers.py
Normal file
216
app/assets/helpers.py
Normal file
@ -0,0 +1,216 @@
import contextlib
import os
from aiohttp import web
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal, Any

import folder_paths


RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")


def get_query_dict(request: web.Request) -> dict[str, Any]:
    """
    Gets a dictionary of query parameters from the request.

    'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
    """
    query_dict = {
        key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key)
        for key in request.query.keys()
    }
    return query_dict


def list_tree(base_dir: str) -> list[str]:
    out: list[str] = []
    base_abs = os.path.abspath(base_dir)
    if not os.path.isdir(base_abs):
        return out
    for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
        for name in filenames:
            out.append(os.path.abspath(os.path.join(dirpath, name)))
    return out


def prefixes_for_root(root: RootType) -> list[str]:
    if root == "models":
        bases: list[str] = []
        for _bucket, paths in get_comfy_models_folders():
            bases.extend(paths)
        return [os.path.abspath(p) for p in bases]
    if root == "input":
        return [os.path.abspath(folder_paths.get_input_directory())]
    if root == "output":
        return [os.path.abspath(folder_paths.get_output_directory())]
    return []


def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
    """Escapes %, _ and the escape char itself in a LIKE prefix.
    Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
    """
    s = s.replace(escape, escape + escape)  # escape the escape char first
    s = s.replace("%", escape + "%").replace("_", escape + "_")  # escape LIKE wildcards
    return s, escape


def fast_asset_file_check(
    *,
    mtime_db: int | None,
    size_db: int | None,
    stat_result: os.stat_result,
) -> bool:
    if mtime_db is None:
        return False
    actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
    if int(mtime_db) != int(actual_mtime_ns):
        return False
    sz = int(size_db or 0)
    if sz > 0:
        return int(stat_result.st_size) == sz
    return True


def utcnow() -> datetime:
    """Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
    return datetime.now(timezone.utc).replace(tzinfo=None)


def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
    """Build a list of (folder_name, base_paths[]) categories that are configured for model locations.

    We trust `folder_paths.folder_names_and_paths` and include a category if
    *any* of its base paths lies under the Comfy `models_dir`.
    """
    targets: list[tuple[str, list[str]]] = []
    models_root = os.path.abspath(folder_paths.models_dir)
    for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
        if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
            targets.append((name, paths))
    return targets


def compute_relative_filename(file_path: str) -> str | None:
    """
    Return the model's path relative to the last well-known folder (the model category),
    using forward slashes, eg:
        /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
        /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"

    For non-model paths, returns None.
    NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
    """
    try:
        root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
    except ValueError:
        return None

    p = Path(rel_path)
    parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
    if not parts:
        return None

    if root_category == "models":
        # parts[0] is the category ("checkpoints", "vae", etc) – drop it
        inside = parts[1:] if len(parts) > 1 else [parts[0]]
        return "/".join(inside)
    return "/".join(parts)  # input/output: keep all parts


def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
    """Given an absolute or relative file path, determine which root category the path belongs to:
    - 'input' if the file resides under `folder_paths.get_input_directory()`
    - 'output' if the file resides under `folder_paths.get_output_directory()`
    - 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`

    Returns:
        (root_category, relative_path_inside_that_root)
        For 'models', the relative path is prefixed with the category name:
        e.g. ('models', 'vae/test/sub/ae.safetensors')

    Raises:
        ValueError: if the path does not belong to input, output, or configured model bases.
    """
    fp_abs = os.path.abspath(file_path)

    def _is_within(child: str, parent: str) -> bool:
        try:
            return os.path.commonpath([child, parent]) == parent
        except Exception:
            return False

    def _rel(child: str, parent: str) -> str:
        return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)

    # 1) input
    input_base = os.path.abspath(folder_paths.get_input_directory())
    if _is_within(fp_abs, input_base):
        return "input", _rel(fp_abs, input_base)

    # 2) output
    output_base = os.path.abspath(folder_paths.get_output_directory())
    if _is_within(fp_abs, output_base):
        return "output", _rel(fp_abs, output_base)

    # 3) models (check deepest matching base to avoid ambiguity)
    best: tuple[int, str, str] | None = None  # (base_len, bucket, rel_inside_bucket)
    for bucket, bases in get_comfy_models_folders():
        for b in bases:
            base_abs = os.path.abspath(b)
            if not _is_within(fp_abs, base_abs):
                continue
            cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
            if best is None or cand[0] > best[0]:
                best = cand

    if best is not None:
        _, bucket, rel_inside = best
        combined = os.path.join(bucket, rel_inside)
        return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)

    raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")


def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
    """Return a tuple (name, tags) derived from a filesystem path.

    Semantics:
    - Root category is determined by `get_relative_to_root_category_path_of_asset`.
    - The returned `name` is the base filename with extension from the relative path.
    - The returned `tags` are:
        [root_category] + parent folders of the relative path (in order)
      For 'models', this means:
        file '/.../ModelsDir/vae/test_tag/ae.safetensors'
        -> root_category='models', some_path='vae/test_tag/ae.safetensors'
        -> name='ae.safetensors', tags=['models', 'vae', 'test_tag']

    Raises:
        ValueError: if the path does not belong to input, output, or configured model bases.
    """
    root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
    p = Path(some_path)
    parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
    return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))


def normalize_tags(tags: list[str] | None) -> list[str]:
    """
    Normalize a list of tags by:
    - Stripping whitespace and converting to lowercase.
    - Removing duplicates.
    """
    return [t.strip().lower() for t in (tags or []) if (t or "").strip()]


def collect_models_files() -> list[str]:
    out: list[str] = []
    for folder_name, bases in get_comfy_models_folders():
        rel_files = folder_paths.get_filename_list(folder_name) or []
        for rel_path in rel_files:
            abs_path = folder_paths.get_full_path(folder_name, rel_path)
            if not abs_path:
                continue
            abs_path = os.path.abspath(abs_path)
            allowed = False
            for b in bases:
                base_abs = os.path.abspath(b)
                with contextlib.suppress(Exception):
                    if os.path.commonpath([abs_path, base_abs]) == base_abs:
                        allowed = True
                        break
            if allowed:
                out.append(abs_path)
    return out
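A short usage sketch of the helpers above; the base directory, model path, and resulting tags are illustrative and depend on the local folder_paths configuration:

import os
from app.assets.helpers import escape_like_prefix, get_name_and_tags_from_asset_path

# Prefix-match rows under a base directory without % or _ acting as LIKE wildcards.
base = os.path.abspath("/data/models/checkpoints") + os.sep
escaped, esc = escape_like_prefix(base)
# e.g. AssetCacheState.file_path.like(escaped + "%", escape=esc)

# Derive a display name and tags from a model path that lives under models_dir.
name, tags = get_name_and_tags_from_asset_path("/data/models/checkpoints/sdxl/base.safetensors")
# roughly: name == "base.safetensors", tags == ["models", "checkpoints", "sdxl"]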
123  app/assets/manager.py  Normal file
@@ -0,0 +1,123 @@
from typing import Sequence

from app.database.db import create_session
from app.assets.api import schemas_out
from app.assets.database.queries import (
    asset_exists_by_hash,
    fetch_asset_info_asset_and_tags,
    list_asset_infos_page,
    list_tags_with_usage,
)


def _safe_sort_field(requested: str | None) -> str:
    if not requested:
        return "created_at"
    v = requested.lower()
    if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
        return v
    return "created_at"


def asset_exists(asset_hash: str) -> bool:
    with create_session() as session:
        return asset_exists_by_hash(session, asset_hash=asset_hash)


def list_assets(
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
    name_contains: str | None = None,
    metadata_filter: dict | None = None,
    limit: int = 20,
    offset: int = 0,
    sort: str = "created_at",
    order: str = "desc",
    owner_id: str = "",
) -> schemas_out.AssetsList:
    sort = _safe_sort_field(sort)
    order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()

    with create_session() as session:
        infos, tag_map, total = list_asset_infos_page(
            session,
            owner_id=owner_id,
            include_tags=include_tags,
            exclude_tags=exclude_tags,
            name_contains=name_contains,
            metadata_filter=metadata_filter,
            limit=limit,
            offset=offset,
            sort=sort,
            order=order,
        )

        summaries: list[schemas_out.AssetSummary] = []
        for info in infos:
            asset = info.asset
            tags = tag_map.get(info.id, [])
            summaries.append(
                schemas_out.AssetSummary(
                    id=info.id,
                    name=info.name,
                    asset_hash=asset.hash if asset else None,
                    size=int(asset.size_bytes) if asset else None,
                    mime_type=asset.mime_type if asset else None,
                    tags=tags,
                    preview_url=f"/api/assets/{info.id}/content",
                    created_at=info.created_at,
                    updated_at=info.updated_at,
                    last_access_time=info.last_access_time,
                )
            )

    return schemas_out.AssetsList(
        assets=summaries,
        total=total,
        has_more=(offset + len(summaries)) < total,
    )


def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
    with create_session() as session:
        res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
        if not res:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        info, asset, tag_names = res
        preview_id = info.preview_id

        return schemas_out.AssetDetail(
            id=info.id,
            name=info.name,
            asset_hash=asset.hash if asset else None,
            size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
            mime_type=asset.mime_type if asset else None,
            tags=tag_names,
            user_metadata=info.user_metadata or {},
            preview_id=preview_id,
            created_at=info.created_at,
            last_access_time=info.last_access_time,
        )


def list_tags(
    prefix: str | None = None,
    limit: int = 100,
    offset: int = 0,
    order: str = "count_desc",
    include_zero: bool = True,
    owner_id: str = "",
) -> schemas_out.TagsList:
    limit = max(1, min(1000, limit))
    offset = max(0, offset)

    with create_session() as session:
        rows, total = list_tags_with_usage(
            session,
            prefix=prefix,
            limit=limit,
            offset=offset,
            include_zero=include_zero,
            order=order,
            owner_id=owner_id,
        )

    tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
    return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
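A sketch of how an API handler might call this manager layer; the tag names and pagination values are illustrative:

from app.assets import manager

page = manager.list_assets(include_tags=["models", "checkpoints"], limit=20, offset=0, sort="name", order="asc")
for item in page.assets:
    print(item.name, item.size, item.tags)
print("has more:", page.has_more)

tag_page = manager.list_tags(prefix="check", limit=50)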
229  app/assets/scanner.py  Normal file
@@ -0,0 +1,229 @@
import contextlib
import time
import logging
import os
import sqlalchemy

import folder_paths
from app.database.db import create_session, dependencies_available
from app.assets.helpers import (
    collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
    list_tree, prefixes_for_root, escape_like_prefix,
    RootType
)
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
from app.assets.database.bulk_ops import seed_from_paths_batch
from app.assets.database.models import Asset, AssetCacheState, AssetInfo


def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
    """
    Scan the given roots and seed the assets into the database.
    """
    if not dependencies_available():
        if enable_logging:
            logging.warning("Database dependencies not available, skipping assets scan")
        return
    t_start = time.perf_counter()
    created = 0
    skipped_existing = 0
    paths: list[str] = []
    try:
        existing_paths: set[str] = set()
        for r in roots:
            try:
                survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
                if survivors:
                    existing_paths.update(survivors)
            except Exception as e:
                logging.exception("fast DB scan failed for %s: %s", r, e)

        if "models" in roots:
            paths.extend(collect_models_files())
        if "input" in roots:
            paths.extend(list_tree(folder_paths.get_input_directory()))
        if "output" in roots:
            paths.extend(list_tree(folder_paths.get_output_directory()))

        specs: list[dict] = []
        tag_pool: set[str] = set()
        for p in paths:
            abs_p = os.path.abspath(p)
            if abs_p in existing_paths:
                skipped_existing += 1
                continue
            try:
                stat_p = os.stat(abs_p, follow_symlinks=False)
            except OSError:
                continue
            # skip empty files
            if not stat_p.st_size:
                continue
            name, tags = get_name_and_tags_from_asset_path(abs_p)
            specs.append(
                {
                    "abs_path": abs_p,
                    "size_bytes": stat_p.st_size,
                    "mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
                    "info_name": name,
                    "tags": tags,
                    "fname": compute_relative_filename(abs_p),
                }
            )
            for t in tags:
                tag_pool.add(t)
        # if no file specs, nothing to do
        if not specs:
            return
        with create_session() as sess:
            if tag_pool:
                ensure_tags_exist(sess, tag_pool, tag_type="user")

            result = seed_from_paths_batch(sess, specs=specs, owner_id="")
            created += result["inserted_infos"]
            sess.commit()
    finally:
        if enable_logging:
            logging.info(
                "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
                roots,
                time.perf_counter() - t_start,
                created,
                skipped_existing,
                len(paths),
            )


def _fast_db_consistency_pass(
    root: RootType,
    *,
    collect_existing_paths: bool = False,
    update_missing_tags: bool = False,
) -> set[str] | None:
    """Fast DB+FS pass for a root:
    - Toggle needs_verify per state using fast check
    - For hashed assets with at least one fast-ok state in this root: delete stale missing states
    - For seed assets with all states missing: delete Asset and its AssetInfos
    - Optionally add/remove 'missing' tags based on fast-ok in this root
    - Optionally return surviving absolute paths
    """
    prefixes = prefixes_for_root(root)
    if not prefixes:
        return set() if collect_existing_paths else None

    conds = []
    for p in prefixes:
        base = os.path.abspath(p)
        if not base.endswith(os.sep):
            base += os.sep
        escaped, esc = escape_like_prefix(base)
        conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))

    with create_session() as sess:
        rows = (
            sess.execute(
                sqlalchemy.select(
                    AssetCacheState.id,
                    AssetCacheState.file_path,
                    AssetCacheState.mtime_ns,
                    AssetCacheState.needs_verify,
                    AssetCacheState.asset_id,
                    Asset.hash,
                    Asset.size_bytes,
                )
                .join(Asset, Asset.id == AssetCacheState.asset_id)
                .where(sqlalchemy.or_(*conds))
                .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
            )
        ).all()

        by_asset: dict[str, dict] = {}
        for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
            acc = by_asset.get(aid)
            if acc is None:
                acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
                by_asset[aid] = acc

            fast_ok = False
            try:
                exists = True
                fast_ok = fast_asset_file_check(
                    mtime_db=mtime_db,
                    size_db=acc["size_db"],
                    stat_result=os.stat(fp, follow_symlinks=True),
                )
            except FileNotFoundError:
                exists = False
            except OSError:
                exists = False

            acc["states"].append({
                "sid": sid,
                "fp": fp,
                "exists": exists,
                "fast_ok": fast_ok,
                "needs_verify": bool(needs_verify),
            })

        to_set_verify: list[int] = []
        to_clear_verify: list[int] = []
        stale_state_ids: list[int] = []
        survivors: set[str] = set()

        for aid, acc in by_asset.items():
            a_hash = acc["hash"]
            states = acc["states"]
            any_fast_ok = any(s["fast_ok"] for s in states)
            all_missing = all(not s["exists"] for s in states)

            for s in states:
                if not s["exists"]:
                    continue
                if s["fast_ok"] and s["needs_verify"]:
                    to_clear_verify.append(s["sid"])
                if not s["fast_ok"] and not s["needs_verify"]:
                    to_set_verify.append(s["sid"])

            if a_hash is None:
                if states and all_missing:  # remove seed Asset completely, if no valid AssetCache exists
                    sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
                    asset = sess.get(Asset, aid)
                    if asset:
                        sess.delete(asset)
                else:
                    for s in states:
                        if s["exists"]:
                            survivors.add(os.path.abspath(s["fp"]))
                continue

            if any_fast_ok:  # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
                for s in states:
                    if not s["exists"]:
                        stale_state_ids.append(s["sid"])
                if update_missing_tags:
                    with contextlib.suppress(Exception):
                        remove_missing_tag_for_asset_id(sess, asset_id=aid)
            elif update_missing_tags:
                with contextlib.suppress(Exception):
                    add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")

            for s in states:
                if s["exists"]:
                    survivors.add(os.path.abspath(s["fp"]))

        if stale_state_ids:
            sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
        if to_set_verify:
            sess.execute(
                sqlalchemy.update(AssetCacheState)
                .where(AssetCacheState.id.in_(to_set_verify))
                .values(needs_verify=True)
            )
        if to_clear_verify:
            sess.execute(
                sqlalchemy.update(AssetCacheState)
                .where(AssetCacheState.id.in_(to_clear_verify))
                .values(needs_verify=False)
            )
        sess.commit()

    return survivors if collect_existing_paths else None
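For reference, a minimal sketch of what the fast consistency check relies on: comparing the stored nanosecond mtime and byte size against a fresh os.stat, with no hashing involved; the path is illustrative:

import os
from app.assets.helpers import fast_asset_file_check

st = os.stat("/data/models/vae/ae.safetensors", follow_symlinks=True)
unchanged = fast_asset_file_check(
    mtime_db=st.st_mtime_ns,  # value that would have been stored in AssetCacheState.mtime_ns
    size_db=st.st_size,       # value that would have been stored in Asset.size_bytes
    stat_result=st,
)
# unchanged is True here, so the scanner would clear needs_verify for this state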
@@ -1,14 +1,21 @@
-from sqlalchemy.orm import declarative_base
+from typing import Any
+from datetime import datetime
+from sqlalchemy.orm import DeclarativeBase
 
-Base = declarative_base()
+
+class Base(DeclarativeBase):
+    pass
 
-def to_dict(obj):
+
+def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
     fields = obj.__table__.columns.keys()
-    return {
-        field: (val.to_dict() if hasattr(val, "to_dict") else val)
-        for field in fields
-        if (val := getattr(obj, field))
-    }
+    out: dict[str, Any] = {}
+    for field in fields:
+        val = getattr(obj, field)
+        if val is None and not include_none:
+            continue
+        if isinstance(val, datetime):
+            out[field] = val.isoformat()
+        else:
+            out[field] = val
+    return out
 
 # TODO: Define models here
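A small sketch of the new to_dict behaviour using a hypothetical mapped class; the class name and columns are illustrative, only Base and to_dict come from the diff above:

from datetime import datetime
from sqlalchemy import String, DateTime
from sqlalchemy.orm import Mapped, mapped_column

class ExampleRow(Base):
    __tablename__ = "example_rows"
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    created_at: Mapped[datetime | None] = mapped_column(DateTime(), nullable=True)

row = ExampleRow(id="abc", created_at=datetime(2025, 12, 10))
to_dict(row)                                       # {'id': 'abc', 'created_at': '2025-12-10T00:00:00'}
to_dict(ExampleRow(id="abc"))                      # {'id': 'abc'}  (None values dropped by default)
to_dict(ExampleRow(id="abc"), include_none=True)   # {'id': 'abc', 'created_at': None}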
@@ -224,6 +224,7 @@ database_default_path = os.path.abspath(
     os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
 )
 parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
+parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
 
 if comfy.options.args_parsing:
     args = parser.parse_args()
@@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
 
 
 @torch.no_grad()
-def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
+def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
     """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
     arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
     """
+    if solver_type not in {"phi_1", "phi_2"}:
+        raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
+
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
         denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
 
         # Step 2
-        denoised_d = torch.lerp(denoised, denoised_2, fac)
-        x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+        if solver_type == "phi_1":
+            denoised_d = torch.lerp(denoised, denoised_2, fac)
+            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+        elif solver_type == "phi_2":
+            b2 = ei_h_phi_2(-h_eta) / r
+            b1 = ei_h_phi_1(-h_eta) - b2
+            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
 
         if inject_noise:
             segment_factor = (r - 1) * h * eta
             sde_noise = sde_noise * segment_factor.exp()
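A hedged note on the two solver types: ei_h_phi_1 and ei_h_phi_2 are assumed here to be the standard exponential-integrator phi functions, and the phi_2 branch is assumed to use the usual second-order weights b2 = phi_2 / r and b1 = phi_1 - b2; a minimal sketch under that assumption:

import torch

def phi_1(z: torch.Tensor) -> torch.Tensor:
    # assumed definition: phi_1(z) = (e^z - 1) / z
    return torch.expm1(z) / z

def phi_2(z: torch.Tensor) -> torch.Tensor:
    # assumed definition: phi_2(z) = (e^z - z - 1) / z^2 = (phi_1(z) - 1) / z
    return (phi_1(z) - 1.0) / z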
@@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
     get_sampler = execute
 
 
+class SamplerSEEDS2(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SamplerSEEDS2",
+            category="sampling/custom_sampling/samplers",
+            inputs=[
+                io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
+                io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
+                io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
+                io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
+            ],
+            outputs=[io.Sampler.Output()]
+        )
+
+    @classmethod
+    def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
+        sampler_name = "seeds_2"
+        sampler = comfy.samplers.ksampler(
+            sampler_name,
+            {"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
+        )
+        return io.NodeOutput(sampler)
+
+
 class Noise_EmptyNoise:
     def __init__(self):
         self.seed = 0
@@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
             SamplerDPMAdaptative,
             SamplerER_SDE,
             SamplerSASolver,
+            SamplerSEEDS2,
             SplitSigmas,
             SplitSigmasDenoise,
             FlipSigmas,
3  main.py
@@ -7,6 +7,7 @@ import folder_paths
 import time
 from comfy.cli_args import args
 from app.logger import setup_logger
+from app.assets.scanner import seed_assets
 import itertools
 import utils.extra_config
 import logging
@@ -326,6 +327,8 @@ def setup_database():
         from app.database.db import init_db, dependencies_available
         if dependencies_available():
             init_db()
+            if not args.disable_assets_autoscan:
+                seed_assets(["models"], enable_logging=True)
     except Exception as e:
         logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
 
@@ -32,6 +32,8 @@ import node_helpers
 from comfyui_version import __version__
 from app.frontend_management import FrontendManager, parse_version
 from comfy_api.internal import _ComfyNodeInternal
+from app.assets.scanner import seed_assets
+from app.assets.api.routes import register_assets_system
 
 from app.user_manager import UserManager
 from app.model_manager import ModelFileManager
@@ -228,6 +230,7 @@ class PromptServer():
                 else args.front_end_root
             )
             logging.info(f"[Prompt Server] web root: {self.web_root}")
+        register_assets_system(self.app, self.user_manager)
         routes = web.RouteTableDef()
         self.routes = routes
         self.last_node_id = None
@@ -676,6 +679,7 @@ class PromptServer():
 
         @routes.get("/object_info")
         async def get_object_info(request):
+            seed_assets(["models"])
             with folder_paths.cache_helper:
                 out = {}
                 for x in nodes.NODE_CLASS_MAPPINGS: