mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 01:52:59 +08:00
960 lines
31 KiB
Python
960 lines
31 KiB
Python
import contextlib
|
|
import os
|
|
import logging
|
|
from collections import defaultdict
|
|
from datetime import datetime
|
|
from decimal import Decimal
|
|
from typing import Any, Sequence, Optional, Iterable
|
|
|
|
import sqlalchemy as sa
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, delete, exists, func
|
|
from sqlalchemy.orm import contains_eager
|
|
from sqlalchemy.exc import IntegrityError
|
|
|
|
from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
|
|
from .timeutil import utcnow
|
|
from .._assets_helpers import normalize_tags
|
|
|
|
|
|
async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
    """Return True when at least one Asset row has ``hash == asset_hash``.

    Only a literal is selected (LIMIT 1), so no Asset entity is loaded.
    """
    probe = (
        select(sa.literal(True))
        .select_from(Asset)
        .where(Asset.hash == asset_hash)
        .limit(1)
    )
    result = await session.execute(probe)
    return result.first() is not None
|
|
|
|
|
|
async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Optional[Asset]:
    """Fetch the Asset whose primary-key identity is ``asset_hash``; None if absent."""
    asset = await session.get(Asset, asset_hash)
    return asset
|
|
|
|
|
|
async def check_fs_asset_exists_quick(
    session,
    *,
    file_path: str,
    size_bytes: Optional[int] = None,
    mtime_ns: Optional[int] = None,
) -> bool:
    """
    Cheap existence probe for a filesystem-backed Asset.

    Returns True when an Asset with storage_backend "fs" has a storage_locator equal to
    the absolute form of `file_path`,
    AND (if `mtime_ns` is given) a joined AssetLocatorState row carries that mtime_ns,
    AND (if `size_bytes` is given) the stored size is either 0 or equal to it.
    """
    abs_locator = os.path.abspath(file_path)

    query = select(sa.literal(True)).select_from(Asset)
    filters = [
        Asset.storage_backend == "fs",
        Asset.storage_locator == abs_locator,
    ]

    if size_bytes is not None:
        # A stored size of 0 is treated as "unknown" and never disqualifies the match.
        size_matches = sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes))
        filters.append(size_matches)

    if mtime_ns is not None:
        # Inner join: without a locator-state row for the asset there is no match.
        query = query.join(AssetLocatorState, AssetLocatorState.asset_hash == Asset.hash)
        filters.append(AssetLocatorState.mtime_ns == int(mtime_ns))

    result = await session.execute(query.where(*filters).limit(1))
    return result.first() is not None
|
|
|
|
|
|
async def ingest_fs_asset(
    session: AsyncSession,
    *,
    asset_hash: str,
    abs_path: str,
    size_bytes: int,
    mtime_ns: int,
    mime_type: Optional[str] = None,
    info_name: Optional[str] = None,
    owner_id: str = "",
    preview_hash: Optional[str] = None,
    user_metadata: Optional[dict] = None,
    tags: Sequence[str] = (),
    tag_origin: str = "manual",
    added_by: Optional[str] = None,
    require_existing_tags: bool = False,
) -> dict:
    """
    Creates or updates Asset record for a local (fs) asset.

    Always:
      - Insert Asset if missing; else update size_bytes, mime_type (when given) and
        storage_locator if different, bumping updated_at on any change.
      - Insert AssetLocatorState if missing; else update mtime_ns if different.

    Optionally (when info_name is provided):
      - Create an AssetInfo (no refcount changes), or reuse the row already stored for
        (asset_hash, owner_id, name); refresh its preview_hash and last_access_time.
      - Link provided tags to that AssetInfo.
        * If the require_existing_tags=True, raises ValueError if any tag does not exist in `tags` table.
        * If False (default), create unknown tags.
      - Rebuild the metadata projection when user_metadata is not None.

    Returns flags and ids:
      {
        "asset_created": bool,
        "asset_updated": bool,
        "state_created": bool,
        "state_updated": bool,
        "asset_info_id": int | None,
      }

    Raises:
      ValueError: require_existing_tags=True and at least one tag is unknown.
      RuntimeError: the AssetInfo row can be neither inserted nor found.
    """
    locator = os.path.abspath(abs_path)
    datetime_now = utcnow()
    # Normalize once so the insert and update paths compare/store the same values.
    desired_size = int(size_bytes)
    desired_mtime = int(mtime_ns)

    out = {
        "asset_created": False,
        "asset_updated": False,
        "state_created": False,
        "state_updated": False,
        "asset_info_id": None,
    }

    # ---- Step 1: INSERT Asset or UPDATE size_bytes/updated_at if exists ----
    # The insert runs inside a nested transaction so that a conflict only undoes this attempt.
    with contextlib.suppress(IntegrityError):
        async with session.begin_nested():
            session.add(
                Asset(
                    hash=asset_hash,
                    size_bytes=desired_size,
                    mime_type=mime_type,
                    refcount=0,
                    storage_backend="fs",
                    storage_locator=locator,
                    created_at=datetime_now,
                    updated_at=datetime_now,
                )
            )
            await session.flush()
            out["asset_created"] = True

    if not out["asset_created"]:
        existing = await session.get(Asset, asset_hash)
        if existing is not None:
            changed = False
            if existing.size_bytes != desired_size:
                existing.size_bytes = desired_size
                changed = True
            if mime_type and existing.mime_type != mime_type:
                existing.mime_type = mime_type
                changed = True
            if existing.storage_locator != locator:
                existing.storage_locator = locator
                changed = True
            if changed:
                existing.updated_at = datetime_now
                out["asset_updated"] = True
        else:
            logging.error("Asset %s not found after PK conflict; skipping update.", asset_hash)

    # ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ----
    with contextlib.suppress(IntegrityError):
        async with session.begin_nested():
            session.add(
                AssetLocatorState(
                    asset_hash=asset_hash,
                    mtime_ns=desired_mtime,
                )
            )
            await session.flush()
            out["state_created"] = True

    if not out["state_created"]:
        state = await session.get(AssetLocatorState, asset_hash)
        if state is not None:
            if state.mtime_ns != desired_mtime:
                state.mtime_ns = desired_mtime
                out["state_updated"] = True
        else:
            logging.error("Locator state missing for %s after conflict; skipping update.", asset_hash)

    # ---- Optional: AssetInfo + tag links ----
    if info_name:
        # 2a) Upsert AssetInfo idempotently on (asset_hash, owner_id, name)
        with contextlib.suppress(IntegrityError):
            async with session.begin_nested():
                info = AssetInfo(
                    owner_id=owner_id,
                    name=info_name,
                    asset_hash=asset_hash,
                    preview_hash=preview_hash,
                    created_at=datetime_now,
                    updated_at=datetime_now,
                    last_access_time=datetime_now,
                )
                session.add(info)
                await session.flush()  # get info.id
                out["asset_info_id"] = info.id

        # Re-read the row: it is either the one just inserted or the pre-existing one
        # that made the insert above fail.
        existing_info = (
            await session.execute(
                select(AssetInfo)
                .where(
                    AssetInfo.asset_hash == asset_hash,
                    AssetInfo.name == info_name,
                    (AssetInfo.owner_id == owner_id),
                )
                .limit(1)
            )
        ).unique().scalar_one_or_none()
        if not existing_info:
            raise RuntimeError("Failed to update or insert AssetInfo.")

        if preview_hash is not None and existing_info.preview_hash != preview_hash:
            existing_info.preview_hash = preview_hash
            existing_info.updated_at = datetime_now
        # A stored last_access_time of None must not be compared with a datetime.
        last_access = existing_info.last_access_time
        if last_access is None or last_access < datetime_now:
            existing_info.last_access_time = datetime_now
        await session.flush()
        out["asset_info_id"] = existing_info.id

        # 2b) Link tags (if any). Unknown tags are created unless require_existing_tags=True.
        # dict.fromkeys drops repeats while keeping first-seen order, so inputs that
        # normalize to the same name (e.g. "A" and "a ") yield a single link row.
        norm = list(
            dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip())
        )
        if norm and out["asset_info_id"] is not None:
            if not require_existing_tags:
                await _ensure_tags_exist(session, norm, tag_type="user")

            # Which tags exist?
            existing_tag_names = {
                name
                for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
            }
            missing = [t for t in norm if t not in existing_tag_names]
            if missing and require_existing_tags:
                raise ValueError(f"Unknown tags: {missing}")

            # Which links already exist?
            existing_links = {
                tag_name
                for (tag_name,) in (
                    await session.execute(
                        select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
                    )
                ).all()
            }
            to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
            if to_add:
                session.add_all(
                    [
                        AssetInfoTag(
                            asset_info_id=out["asset_info_id"],
                            tag_name=t,
                            origin=tag_origin,
                            added_by=added_by,
                            added_at=datetime_now,
                        )
                        for t in to_add
                    ]
                )
                await session.flush()

        # 2c) Rebuild metadata projection if provided
        if user_metadata is not None and out["asset_info_id"] is not None:
            await replace_asset_info_metadata_projection(
                session,
                asset_info_id=out["asset_info_id"],
                user_metadata=user_metadata,
            )
    return out
|
|
|
|
|
|
async def touch_asset_infos_by_fs_path(
    session: AsyncSession,
    *,
    abs_path: str,
    ts: Optional[datetime] = None,
    only_if_newer: bool = True,
) -> int:
    """Set last_access_time on every AssetInfo whose Asset is the "fs" asset at `abs_path`.

    `abs_path` is passed through os.path.abspath before matching storage_locator.
    `ts` falls back to utcnow() when falsy. With only_if_newer=True, only rows whose
    last_access_time is NULL or earlier than the timestamp are updated.

    Returns the rowcount reported for the UPDATE (0 when it is falsy).
    """
    target = os.path.abspath(abs_path)
    stamp = ts or utcnow()

    # Correlated subquery: the AssetInfo's asset must be stored on fs at this locator.
    backing_asset = (
        sa.select(sa.literal(1))
        .select_from(Asset)
        .where(
            Asset.hash == AssetInfo.asset_hash,
            Asset.storage_backend == "fs",
            Asset.storage_locator == target,
        )
    )
    update_stmt = sa.update(AssetInfo).where(sa.exists(backing_asset))

    if only_if_newer:
        is_stale = sa.or_(
            AssetInfo.last_access_time.is_(None),
            AssetInfo.last_access_time < stamp,
        )
        update_stmt = update_stmt.where(is_stale)

    update_stmt = update_stmt.values(last_access_time=stamp)
    result = await session.execute(update_stmt)
    return int(result.rowcount or 0)
|
|
|
|
|
|
async def touch_asset_info_by_id(
    session: AsyncSession,
    *,
    asset_info_id: int,
    ts: Optional[datetime] = None,
    only_if_newer: bool = True,
) -> int:
    """Set last_access_time on the AssetInfo with the given id.

    `ts` falls back to utcnow() when falsy. With only_if_newer=True the row is only
    updated when its last_access_time is NULL or earlier than the timestamp.

    Returns the rowcount reported for the UPDATE (0 when it is falsy).
    """
    stamp = ts or utcnow()

    criteria = [AssetInfo.id == asset_info_id]
    if only_if_newer:
        criteria.append(
            sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < stamp)
        )

    update_stmt = sa.update(AssetInfo).where(*criteria).values(last_access_time=stamp)
    result = await session.execute(update_stmt)
    return int(result.rowcount or 0)
|
|
|
|
|
|
async def list_asset_infos_page(
    session: AsyncSession,
    *,
    include_tags: Optional[Sequence[str]] = None,
    exclude_tags: Optional[Sequence[str]] = None,
    name_contains: Optional[str] = None,
    metadata_filter: Optional[dict] = None,
    limit: int = 20,
    offset: int = 0,
    sort: str = "created_at",
    order: str = "desc",
) -> tuple[list[AssetInfo], dict[int, list[str]], int]:
    """
    Returns a page of AssetInfo rows with their Asset eagerly loaded (no N+1),
    plus a map of asset_info_id -> [tags], and the total count.

    We purposely collect tags in a separate (single) query to avoid row explosion.

    Arguments:
      - include_tags / exclude_tags: forwarded to _apply_tag_filters.
      - name_contains: case-insensitive substring match on AssetInfo.name; `%` and `_`
        in the input are matched literally.
      - metadata_filter: forwarded to _apply_metadata_filter.
      - limit: clamped to 1..100; offset: clamped to >= 0.
      - sort: one of "name", "created_at", "updated_at", "last_access_time", "size";
        anything else falls back to "created_at".
      - order: "desc" sorts descending, any other value ascending.

    The total count uses the same filters as the page but ignores ordering/limit/offset.
    """
    # Clamp
    limit = min(max(limit, 1), 100)
    offset = max(offset, 0)

    # Escape LIKE metacharacters so user input is treated as plain text.
    name_pattern: Optional[str] = None
    if name_contains:
        escaped = name_contains.replace("!", "!!").replace("%", "!%").replace("_", "!_")
        name_pattern = f"%{escaped}%"

    def _filtered(stmt):
        """Apply the name, tag and metadata filters shared by the page and count queries."""
        if name_pattern is not None:
            stmt = stmt.where(AssetInfo.name.ilike(name_pattern, escape="!"))
        stmt = _apply_tag_filters(stmt, include_tags, exclude_tags)
        return _apply_metadata_filter(stmt, metadata_filter)

    # Sort
    sort_map = {
        "name": AssetInfo.name,
        "created_at": AssetInfo.created_at,
        "updated_at": AssetInfo.updated_at,
        "last_access_time": AssetInfo.last_access_time,
        "size": Asset.size_bytes,
    }
    sort_col = sort_map.get((sort or "created_at").lower(), AssetInfo.created_at)
    # AssetInfo.id is appended as a tiebreaker so that rows with equal sort values keep
    # a stable position across pages.
    if (order or "desc").lower() == "desc":
        ordering = (sort_col.desc(), AssetInfo.id.desc())
    else:
        ordering = (sort_col.asc(), AssetInfo.id.asc())

    # Page query
    page_stmt = (
        _filtered(
            select(AssetInfo)
            .join(Asset, Asset.hash == AssetInfo.asset_hash)
            .options(contains_eager(AssetInfo.asset))
        )
        .order_by(*ordering)
        .limit(limit)
        .offset(offset)
    )

    # Total count (same filters, no ordering/limit/offset)
    count_stmt = _filtered(
        select(func.count())
        .select_from(AssetInfo)
        .join(Asset, Asset.hash == AssetInfo.asset_hash)
    )

    total = (await session.execute(count_stmt)).scalar_one()

    # Fetch rows
    infos = list((await session.execute(page_stmt)).scalars().unique().all())

    # Collect tags in bulk (single query)
    id_list = [i.id for i in infos]
    tag_map: dict[int, list[str]] = defaultdict(list)
    if id_list:
        rows = await session.execute(
            select(AssetInfoTag.asset_info_id, Tag.name)
            .join(Tag, Tag.name == AssetInfoTag.tag_name)
            .where(AssetInfoTag.asset_info_id.in_(id_list))
        )
        for aid, tag_name in rows.all():
            tag_map[aid].append(tag_name)

    return infos, tag_map, total
|
|
|
|
|
|
async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: int) -> Optional[tuple[AssetInfo, Asset]]:
    """Load an AssetInfo together with its Asset in one joined query.

    Returns an (AssetInfo, Asset) tuple, or None when no row matches `asset_info_id`.
    """
    query = (
        select(AssetInfo, Asset)
        .join(Asset, Asset.hash == AssetInfo.asset_hash)
        .where(AssetInfo.id == asset_info_id)
        .limit(1)
    )
    found = (await session.execute(query)).first()
    if not found:
        return None
    info, asset = found[0], found[1]
    return info, asset
|
|
|
|
|
|
async def create_asset_info_for_existing_asset(
    session: AsyncSession,
    *,
    asset_hash: str,
    name: str,
    user_metadata: Optional[dict] = None,
    tags: Optional[Sequence[str]] = None,
    tag_origin: str = "manual",
    added_by: Optional[str] = None,
) -> AssetInfo:
    """Add a new AssetInfo row pointing at `asset_hash` (no content write).

    The row is created with an empty owner_id, no preview_hash, and all three
    timestamps set to the same utcnow() value. When `user_metadata` is not None the
    metadata projection is rebuilt; when `tags` is not None the tag set is replaced.
    Returns the flushed AssetInfo.
    """
    stamp = utcnow()
    new_info = AssetInfo(
        owner_id="",
        name=name,
        asset_hash=asset_hash,
        preview_hash=None,
        created_at=stamp,
        updated_at=stamp,
        last_access_time=stamp,
    )
    session.add(new_info)
    # Flush so that new_info.id is populated before it is used below.
    await session.flush()
    info_id = new_info.id

    if user_metadata is not None:
        await replace_asset_info_metadata_projection(
            session,
            asset_info_id=info_id,
            user_metadata=user_metadata,
        )

    if tags is not None:
        await set_asset_info_tags(
            session,
            asset_info_id=info_id,
            tags=tags,
            origin=tag_origin,
            added_by=added_by,
        )

    return new_info
|
|
|
|
|
|
async def set_asset_info_tags(
    session: AsyncSession,
    *,
    asset_info_id: int,
    tags: Sequence[str],
    origin: str = "manual",
    added_by: Optional[str] = None,
) -> dict:
    """
    Make the tag links of an AssetInfo equal to normalize_tags(`tags`).

    Links missing from the desired set are inserted (their Tag rows are ensured first,
    created with tag_type 'user'); links not in the desired set are deleted.

    Returns {"added": [...], "removed": [...], "total": <normalized tags>}.
    """
    wanted = normalize_tags(tags)

    # Tag names currently linked to this AssetInfo.
    link_rows = await session.execute(
        select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
    )
    linked = {row[0] for row in link_rows.all()}

    additions = [t for t in wanted if t not in linked]
    removals = [t for t in linked if t not in wanted]

    if additions:
        await _ensure_tags_exist(session, additions, tag_type="user")
        new_links = []
        for tag in additions:
            new_links.append(
                AssetInfoTag(
                    asset_info_id=asset_info_id,
                    tag_name=tag,
                    origin=origin,
                    added_by=added_by,
                    added_at=utcnow(),
                )
            )
        session.add_all(new_links)
        await session.flush()

    if removals:
        unlink = delete(AssetInfoTag).where(
            AssetInfoTag.asset_info_id == asset_info_id,
            AssetInfoTag.tag_name.in_(removals),
        )
        await session.execute(unlink)
        await session.flush()

    return {"added": additions, "removed": removals, "total": wanted}
|
|
|
|
|
|
async def update_asset_info_full(
    session: AsyncSession,
    *,
    asset_info_id: int,
    name: Optional[str] = None,
    tags: Optional[Sequence[str]] = None,
    user_metadata: Optional[dict] = None,
    tag_origin: str = "manual",
    added_by: Optional[str] = None,
) -> AssetInfo:
    """
    Apply any combination of updates to one AssetInfo:
      - name: assigned when not None and different from the stored name
      - user_metadata: blob replaced and projection rebuilt when not None
      - tags: tag set replaced when not None
    Returns the updated AssetInfo.

    Raises ValueError when no AssetInfo has the given id.
    """
    info = await session.get(AssetInfo, asset_info_id)
    if not info:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    renamed = name is not None and name != info.name
    if renamed:
        info.name = name

    metadata_replaced = user_metadata is not None
    if metadata_replaced:
        await replace_asset_info_metadata_projection(
            session,
            asset_info_id=asset_info_id,
            user_metadata=user_metadata,
        )

    tags_replaced = tags is not None
    if tags_replaced:
        await set_asset_info_tags(
            session,
            asset_info_id=asset_info_id,
            tags=tags,
            origin=tag_origin,
            added_by=added_by,
        )

    # replace_asset_info_metadata_projection already stamps updated_at and flushes,
    # so the stamp is only applied here when the metadata path did not run.
    if not metadata_replaced and (renamed or tags_replaced):
        info.updated_at = utcnow()
        await session.flush()

    return info
|
|
|
|
|
|
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> bool:
    """Delete the user-visible AssetInfo row; True when the DELETE reports affected rows.

    NOTE(review): removal of dependent tag/metadata rows presumably relies on cascades
    declared on the models - confirm there, nothing in this function deletes them.
    """
    stmt = delete(AssetInfo).where(AssetInfo.id == asset_info_id)
    outcome = await session.execute(stmt)
    return bool(outcome.rowcount)
|
|
|
|
|
|
async def replace_asset_info_metadata_projection(
    session: AsyncSession,
    *,
    asset_info_id: int,
    user_metadata: Optional[dict],
) -> None:
    """Overwrite AssetInfo.user_metadata and rebuild its AssetInfoMeta projection rows.

    The blob is stored as `user_metadata or {}` and updated_at is stamped. All existing
    projection rows for the AssetInfo are deleted, then one row per entry produced by
    _project_kv is inserted. Raises ValueError when the AssetInfo does not exist.
    """
    info = await session.get(AssetInfo, asset_info_id)
    if not info:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    info.user_metadata = user_metadata or {}
    info.updated_at = utcnow()
    await session.flush()

    # Drop the old projection before inserting the new one.
    purge = delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)
    await session.execute(purge)
    await session.flush()

    if not user_metadata:
        return

    projected: list[AssetInfoMeta] = [
        AssetInfoMeta(
            asset_info_id=asset_info_id,
            key=entry["key"],
            ordinal=int(entry["ordinal"]),
            val_str=entry.get("val_str"),
            val_num=entry.get("val_num"),
            val_bool=entry.get("val_bool"),
            val_json=entry.get("val_json"),
        )
        for meta_key, meta_value in user_metadata.items()
        for entry in _project_kv(meta_key, meta_value)
    ]
    if not projected:
        return
    session.add_all(projected)
    await session.flush()
|
|
|
|
|
|
async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[str]:
    """Return the tag names linked to the given AssetInfo, in the order the query yields them."""
    stmt = sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
    result = await session.execute(stmt)
    return [row[0] for row in result.all()]
|
|
|
|
|
|
async def list_tags_with_usage(
    session: AsyncSession,
    *,
    prefix: Optional[str] = None,
    limit: int = 100,
    offset: int = 0,
    include_zero: bool = True,
    order: str = "count_desc",  # "count_desc" | "name_asc"
) -> tuple[list[tuple[str, str, int]], int]:
    """
    List tags together with the number of AssetInfo links each one has.

    Arguments:
      - prefix: stripped and lower-cased, then matched literally against the start of
        Tag.name (`%` and `_` are not treated as wildcards).
      - include_zero: when False, tags with no links are dropped.
      - order: "name_asc" sorts by name; anything else sorts by count descending, then name.

    Returns:
        rows: list of (name, tag_type, count)
        total: number of tags matching filter (independent of pagination)
    """
    # Subquery with counts by tag_name
    counts_sq = (
        select(
            AssetInfoTag.tag_name.label("tag_name"),
            func.count(AssetInfoTag.asset_info_id).label("cnt"),
        )
        .group_by(AssetInfoTag.tag_name)
        .subquery()
    )
    usage = func.coalesce(counts_sq.c.cnt, 0)

    # Base select with LEFT JOIN so we can include zero-usage tags
    q = (
        select(Tag.name, Tag.tag_type, usage.label("count"))
        .select_from(Tag)
        .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
    )
    # Total (without limit/offset, same filters)
    total_q = select(func.count()).select_from(Tag)

    # Prefix filter, shared by both queries. autoescape makes LIKE metacharacters in the
    # prefix literal, so e.g. "my_" does not match "myxtag".
    normalized_prefix = prefix.strip().lower() if prefix else ""
    if normalized_prefix:
        prefix_filter = Tag.name.startswith(normalized_prefix, autoescape=True)
        q = q.where(prefix_filter)
        total_q = total_q.where(prefix_filter)

    # Include_zero toggles: if False, drop zero-usage tags
    if not include_zero:
        q = q.where(usage > 0)
        # count only names that appear in the link table
        total_q = total_q.where(
            Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
        )

    # Ordering
    if order == "name_asc":
        q = q.order_by(Tag.name.asc())
    else:  # default "count_desc"
        q = q.order_by(usage.desc(), Tag.name.asc())

    rows = (await session.execute(q.limit(limit).offset(offset))).all()
    total = (await session.execute(total_q)).scalar_one()

    # Normalize counts to int for SQLite/Postgres consistency
    rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
    return rows_norm, int(total or 0)
|
|
|
|
|
|
async def add_tags_to_asset_info(
    session: AsyncSession,
    *,
    asset_info_id: int,
    tags: Sequence[str],
    origin: str = "manual",
    added_by: Optional[str] = None,
    create_if_missing: bool = True,
) -> dict:
    """Link additional tags to an AssetInfo without touching its other tags.

    If create_if_missing=True, missing tag rows are created as 'user'.
    An IntegrityError while inserting the links rolls the whole batch back and is
    not re-raised; the result then reflects whatever links exist afterwards.

    Returns: {"added": [...], "already_present": [...], "total_tags": [...]}
    Raises ValueError when the AssetInfo does not exist.
    """
    info = await session.get(AssetInfo, asset_info_id)
    if not info:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    normalized = normalize_tags(tags)
    if not normalized:
        unchanged = await get_asset_tags(session, asset_info_id=asset_info_id)
        return {"added": [], "already_present": [], "total_tags": unchanged}

    # Ensure tag rows exist if requested.
    if create_if_missing:
        await _ensure_tags_exist(session, normalized, tag_type="user")

    # Snapshot of the links present before inserting anything.
    link_rows = await session.execute(
        sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
    )
    linked_before = {row[0] for row in link_rows.all()}

    requested = set(normalized)
    missing_links = sorted(requested - linked_before)

    if missing_links:
        async with session.begin_nested() as nested:
            try:
                session.add_all(
                    [
                        AssetInfoTag(
                            asset_info_id=asset_info_id,
                            tag_name=tag,
                            origin=origin,
                            added_by=added_by,
                            added_at=utcnow(),
                        )
                        for tag in missing_links
                    ]
                )
                await session.flush()
            except IntegrityError:
                await nested.rollback()

    # Re-read so that "added" reports what actually ended up linked.
    linked_after = set(await get_asset_tags(session, asset_info_id=asset_info_id))
    newly_linked = (linked_after - linked_before) & requested
    return {
        "added": sorted(newly_linked),
        "already_present": sorted(requested & linked_before),
        "total_tags": sorted(linked_after),
    }
|
|
|
|
|
|
async def remove_tags_from_asset_info(
    session: AsyncSession,
    *,
    asset_info_id: int,
    tags: Sequence[str],
) -> dict:
    """Unlink the given tags from an AssetInfo.

    Returns: {"removed": [...], "not_present": [...], "total_tags": [...]}
    Raises ValueError when the AssetInfo does not exist.
    """
    info = await session.get(AssetInfo, asset_info_id)
    if not info:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    normalized = normalize_tags(tags)
    if not normalized:
        unchanged = await get_asset_tags(session, asset_info_id=asset_info_id)
        return {"removed": [], "not_present": [], "total_tags": unchanged}

    link_rows = await session.execute(
        sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
    )
    linked = {row[0] for row in link_rows.all()}

    requested = set(normalized)
    to_remove = sorted(requested & linked)
    not_present = sorted(requested - linked)

    if to_remove:
        unlink = delete(AssetInfoTag).where(
            AssetInfoTag.asset_info_id == asset_info_id,
            AssetInfoTag.tag_name.in_(to_remove),
        )
        await session.execute(unlink)
        await session.flush()

    remaining = await get_asset_tags(session, asset_info_id=asset_info_id)
    return {"removed": to_remove, "not_present": not_present, "total_tags": remaining}
|
|
|
|
|
|
async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
    """Return Tag rows for the normalized `names`, inserting those not yet stored.

    New rows are created with the given `tag_type` and flushed. The result follows the
    order of the normalized names; an empty normalized input returns [].
    """
    wanted = normalize_tags(list(names))
    if not wanted:
        return []

    lookup = select(Tag).where(Tag.name.in_(wanted))
    known: dict = {}
    for tag in (await session.execute(lookup)).scalars().all():
        known[tag.name] = tag

    fresh = [Tag(name=tag_name, tag_type=tag_type) for tag_name in wanted if tag_name not in known]
    if fresh:
        session.add_all(fresh)
        await session.flush()
        for tag in fresh:
            known[tag.name] = tag

    return [known[tag_name] for tag_name in wanted]
|
|
|
|
|
|
def _apply_tag_filters(
    stmt: sa.sql.Select,
    include_tags: Optional[Sequence[str]],
    exclude_tags: Optional[Sequence[str]],
) -> sa.sql.Select:
    """Restrict `stmt` by tags: every include tag must be linked, no exclude tag may be.

    Both tag lists are passed through normalize_tags first. One EXISTS clause is added
    per included tag, and a single NOT EXISTS covers all excluded tags.
    """
    required = normalize_tags(include_tags)
    forbidden = normalize_tags(exclude_tags)

    for required_tag in required or ():
        has_tag = exists().where(
            sa.and_(
                AssetInfoTag.asset_info_id == AssetInfo.id,
                AssetInfoTag.tag_name == required_tag,
            )
        )
        stmt = stmt.where(has_tag)

    if forbidden:
        has_forbidden = exists().where(
            sa.and_(
                AssetInfoTag.asset_info_id == AssetInfo.id,
                AssetInfoTag.tag_name.in_(forbidden),
            )
        )
        stmt = stmt.where(~has_forbidden)

    return stmt
|
|
|
|
def _apply_metadata_filter(
    stmt: sa.sql.Select,
    metadata_filter: Optional[dict],
) -> sa.sql.Select:
    """Apply metadata filters using the projection table asset_info_meta.

    Semantics:
      - For scalar values: require EXISTS(asset_info_meta) with matching key + typed value.
      - For None: key is missing OR key has an explicit null, i.e. a projection row whose
        val_str, val_num, val_bool and val_json are all NULL.
      - For list values: ANY-of the list elements matches (EXISTS for any).
        (Change to ALL-of by 'for each element: stmt = stmt.where(_meta_exists_clause(key, elem))')
      - An empty list adds no condition for its key.

    A falsy `metadata_filter` returns `stmt` unchanged.
    """
    if not metadata_filter:
        return stmt

    def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
        """EXISTS over projection rows of the current AssetInfo for `key` matching `preds`."""
        subquery = (
            select(sa.literal(1))
            .select_from(AssetInfoMeta)
            .where(
                AssetInfoMeta.asset_info_id == AssetInfo.id,
                AssetInfoMeta.key == key,
                *preds,
            )
            .limit(1)
        )
        return sa.exists(subquery)

    def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
        """Build the predicate for one (key, value) pair of the filter."""
        # Missing OR null:
        if value is None:
            # either: no row for key OR a row for key with explicit null
            no_row_for_key = ~_exists_for_pred(key)
            # Checking val_json alone is not enough: rows that carry a str/num/bool value
            # leave val_json empty too, so every typed column has to be NULL.
            null_row = _exists_for_pred(
                key,
                AssetInfoMeta.val_str.is_(None),
                AssetInfoMeta.val_num.is_(None),
                AssetInfoMeta.val_bool.is_(None),
                AssetInfoMeta.val_json.is_(None),
            )
            return sa.or_(no_row_for_key, null_row)

        # Typed scalar matches (bool is tested before int because bool is an int subclass):
        if isinstance(value, bool):
            return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
        if isinstance(value, (int, float, Decimal)):
            # store as Decimal for equality against NUMERIC(38,10)
            num = value if isinstance(value, Decimal) else Decimal(str(value))
            return _exists_for_pred(key, AssetInfoMeta.val_num == num)
        if isinstance(value, str):
            return _exists_for_pred(key, AssetInfoMeta.val_str == value)

        # Complex: compare JSON (no index, but supported)
        return _exists_for_pred(key, AssetInfoMeta.val_json == value)

    for k, v in metadata_filter.items():
        if isinstance(v, list):
            # ANY-of (exists for any element)
            ors = [_exists_clause_for_value(k, elem) for elem in v]
            if ors:
                stmt = stmt.where(sa.or_(*ors))
        else:
            stmt = stmt.where(_exists_clause_for_value(k, v))
    return stmt
|
|
|
|
|
|
def _is_scalar(v: Any) -> bool:
|
|
if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries
|
|
return True
|
|
if isinstance(v, bool):
|
|
return True
|
|
if isinstance(v, (int, float, Decimal, str)):
|
|
return True
|
|
return False
|
|
|
|
|
|
def _project_kv(key: str, value: Any) -> list[dict]:
|
|
"""
|
|
Turn a metadata key/value into one or more projection rows:
|
|
- scalar -> one row (ordinal=0) in the proper typed column
|
|
- list of scalars -> one row per element with ordinal=i
|
|
- dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list)
|
|
- None -> single row with val_json = None
|
|
Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...}
|
|
"""
|
|
rows: list[dict] = []
|
|
|
|
# None
|
|
if value is None:
|
|
rows.append({"key": key, "ordinal": 0, "val_json": None})
|
|
return rows
|
|
|
|
# Scalars
|
|
if _is_scalar(value):
|
|
if isinstance(value, bool):
|
|
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
|
elif isinstance(value, (int, float, Decimal)):
|
|
# store numeric; SQLAlchemy will coerce to Numeric
|
|
rows.append({"key": key, "ordinal": 0, "val_num": value})
|
|
elif isinstance(value, str):
|
|
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
|
else:
|
|
# Fallback to json
|
|
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
|
return rows
|
|
|
|
# Lists
|
|
if isinstance(value, list):
|
|
# list of scalars?
|
|
if all(_is_scalar(x) for x in value):
|
|
for i, x in enumerate(value):
|
|
if x is None:
|
|
rows.append({"key": key, "ordinal": i, "val_json": None})
|
|
elif isinstance(x, bool):
|
|
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
|
elif isinstance(x, (int, float, Decimal)):
|
|
rows.append({"key": key, "ordinal": i, "val_num": x})
|
|
elif isinstance(x, str):
|
|
rows.append({"key": key, "ordinal": i, "val_str": x})
|
|
else:
|
|
rows.append({"key": key, "ordinal": i, "val_json": x})
|
|
return rows
|
|
# list contains objects -> one val_json per element
|
|
for i, x in enumerate(value):
|
|
rows.append({"key": key, "ordinal": i, "val_json": x})
|
|
return rows
|
|
|
|
# Dict or any other structure -> single json row
|
|
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
|
return rows
|