mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 01:37:04 +08:00
268 lines
9.0 KiB
Python
268 lines
9.0 KiB
Python
import sqlalchemy as sa
|
|
from collections import defaultdict
|
|
from sqlalchemy import select, exists, func
|
|
from sqlalchemy.orm import Session, contains_eager, noload
|
|
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
|
from app.assets.helpers import escape_like_prefix, normalize_tags
|
|
from typing import Sequence
|
|
|
|
|
|
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
    """Return the read-visibility predicate on ``AssetInfo.owner_id``.

    Rows whose ``owner_id`` is the empty string are visible to every caller.
    A blank ``owner_id`` (falsy, empty or whitespace-only) therefore matches
    only those owner-less rows; any other value matches the owner-less rows
    plus the rows owned by that id (compared after stripping whitespace).
    """
    normalized = (owner_id or "").strip()
    if normalized:
        return AssetInfo.owner_id.in_(["", normalized])
    return AssetInfo.owner_id == ""
|
|
|
|
|
|
def apply_tag_filters(
    stmt: sa.sql.Select,
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
    """Add tag constraints to ``stmt`` and return the resulting statement.

    Both tag sequences are first passed through ``normalize_tags``.

    * ``include_tags``: one correlated EXISTS per tag, so a row must carry
      every listed tag.
    * ``exclude_tags``: a single NOT EXISTS over the whole list, so a row must
      carry none of them.
    """
    required = normalize_tags(include_tags)
    forbidden = normalize_tags(exclude_tags)

    # Correlates the AssetInfoTag sub-select with the outer AssetInfo row.
    belongs_to_row = AssetInfoTag.asset_info_id == AssetInfo.id

    for required_tag in required or ():
        has_tag = exists().where(
            belongs_to_row & (AssetInfoTag.tag_name == required_tag)
        )
        stmt = stmt.where(has_tag)

    if forbidden:
        has_forbidden = exists().where(
            belongs_to_row & (AssetInfoTag.tag_name.in_(forbidden))
        )
        stmt = stmt.where(~has_forbidden)

    return stmt
|
|
|
|
def apply_metadata_filter(
    stmt: sa.sql.Select,
    metadata_filter: dict | None = None,
) -> sa.sql.Select:
    """Apply filters using the ``asset_info_meta`` projection table.

    Every ``key -> value`` entry of ``metadata_filter`` contributes one WHERE
    clause to ``stmt``; clauses from different keys are combined with AND.

    Value handling:

    * ``None``: matches when no row exists for the key, or a row exists whose
      ``val_json``, ``val_str``, ``val_num`` and ``val_bool`` are all NULL.
    * ``bool``: compared against ``val_bool`` (checked before numbers because
      ``bool`` is a subclass of ``int``).
    * ``int`` / ``float`` / ``Decimal``: compared against ``val_num`` as a
      ``Decimal``.
    * ``str``: compared against ``val_str``.
    * anything else: compared against ``val_json``.
    * ``list``: each element is handled as above and the results are OR-ed.
      An empty list adds no clause at all.

    Returns ``stmt`` unchanged when ``metadata_filter`` is empty or ``None``.
    """
    if not metadata_filter:
        return stmt

    # Local import kept (the module does not import Decimal at top level), but
    # hoisted out of the per-value helper so it is resolved once per call.
    from decimal import Decimal

    def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
        """EXISTS over AssetInfoMeta rows of the outer AssetInfo for ``key``."""
        return sa.exists().where(
            AssetInfoMeta.asset_info_id == AssetInfo.id,
            AssetInfoMeta.key == key,
            *preds,
        )

    def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
        """Build the predicate matching a single scalar filter value."""
        if value is None:
            no_row_for_key = sa.not_(_exists_for_pred(key))
            null_row = _exists_for_pred(
                key,
                AssetInfoMeta.val_json.is_(None),
                AssetInfoMeta.val_str.is_(None),
                AssetInfoMeta.val_num.is_(None),
                AssetInfoMeta.val_bool.is_(None),
            )
            return sa.or_(no_row_for_key, null_row)

        # bool must be tested before the numeric branch: isinstance(True, int).
        if isinstance(value, bool):
            return _exists_for_pred(key, AssetInfoMeta.val_bool == value)

        # Decimal is accepted here explicitly; previously the Decimal check
        # inside this branch was unreachable and Decimal values fell through
        # to the val_json comparison.
        if isinstance(value, (int, float, Decimal)):
            # Going through str() avoids binary float artefacts in the Decimal.
            num = value if isinstance(value, Decimal) else Decimal(str(value))
            return _exists_for_pred(key, AssetInfoMeta.val_num == num)

        if isinstance(value, str):
            return _exists_for_pred(key, AssetInfoMeta.val_str == value)

        return _exists_for_pred(key, AssetInfoMeta.val_json == value)

    for k, v in metadata_filter.items():
        if isinstance(v, list):
            ors = [_exists_clause_for_value(k, elem) for elem in v]
            if ors:
                stmt = stmt.where(sa.or_(*ors))
        else:
            stmt = stmt.where(_exists_clause_for_value(k, v))
    return stmt
|
|
|
|
|
|
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
    """Return ``True`` when at least one ``Asset`` row has ``hash == asset_hash``.

    Only a literal ``TRUE`` limited to one row is selected, so no ``Asset``
    entity is loaded.
    """
    probe = (
        select(sa.literal(True))
        .select_from(Asset)
        .where(Asset.hash == asset_hash)
        .limit(1)
    )
    hit = session.execute(probe).first()
    return hit is not None
|
|
|
|
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
    """Look up an ``AssetInfo`` by identity through ``Session.get``.

    Returns ``None`` when no such row exists. No owner-visibility predicate
    is applied by this function.
    """
    found: AssetInfo | None = session.get(AssetInfo, asset_info_id)
    return found
|
|
|
|
def list_asset_infos_page(
    session: Session,
    owner_id: str = "",
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
    name_contains: str | None = None,
    metadata_filter: dict | None = None,
    limit: int = 20,
    offset: int = 0,
    sort: str = "created_at",
    order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
    """Return one page of visible ``AssetInfo`` rows, their tags and the total.

    Args:
        session: SQLAlchemy session used for all three queries.
        owner_id: Caller identity, turned into a predicate by
            ``visible_owner_clause``.
        include_tags: Tags forwarded to ``apply_tag_filters`` as required.
        exclude_tags: Tags forwarded to ``apply_tag_filters`` as forbidden.
        name_contains: Case-insensitive substring match on ``AssetInfo.name``;
            the text is escaped with ``escape_like_prefix`` before being
            wrapped in ``%``.
        metadata_filter: Forwarded to ``apply_metadata_filter``.
        limit: Page size.
        offset: Number of rows to skip.
        sort: One of ``name``, ``created_at``, ``updated_at``,
            ``last_access_time``, ``size``; anything else falls back to
            ``created_at``.
        order: ``"desc"`` for descending; any other value sorts ascending.

    Returns:
        ``(infos, tag_map, total)`` where ``infos`` is the requested page,
        ``tag_map`` is a ``defaultdict(list)`` from asset-info id to its tag
        names (sorted by name), and ``total`` is the number of rows matching
        the filters irrespective of ``limit``/``offset``.
    """
    # Build the name predicate once so the page and the count cannot diverge.
    name_clause = None
    if name_contains:
        escaped, esc = escape_like_prefix(name_contains)
        name_clause = AssetInfo.name.ilike(f"%{escaped}%", escape=esc)

    def _with_filters(stmt: sa.sql.Select) -> sa.sql.Select:
        """Apply the visibility, name, tag and metadata filters to ``stmt``."""
        stmt = stmt.where(visible_owner_clause(owner_id))
        if name_clause is not None:
            stmt = stmt.where(name_clause)
        stmt = apply_tag_filters(stmt, include_tags, exclude_tags)
        return apply_metadata_filter(stmt, metadata_filter)

    sort = (sort or "created_at").lower()
    order = (order or "desc").lower()
    sort_map = {
        "name": AssetInfo.name,
        "created_at": AssetInfo.created_at,
        "updated_at": AssetInfo.updated_at,
        "last_access_time": AssetInfo.last_access_time,
        "size": Asset.size_bytes,
    }
    sort_col = sort_map.get(sort, AssetInfo.created_at)
    sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()

    page_stmt = _with_filters(
        select(AssetInfo)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
    )
    # The sort column is not unique; without a tiebreaker LIMIT/OFFSET may
    # repeat or skip rows between pages when several rows share a sort value.
    page_stmt = (
        page_stmt.order_by(sort_exp, AssetInfo.id.asc()).limit(limit).offset(offset)
    )

    count_stmt = _with_filters(
        select(sa.func.count())
        .select_from(AssetInfo)
        .join(Asset, Asset.id == AssetInfo.asset_id)
    )

    total = int(session.execute(count_stmt).scalar_one() or 0)
    infos = session.execute(page_stmt).unique().scalars().all()

    # Tags are fetched in one extra query for the whole page instead of being
    # loaded per row (the relationship is noload-ed above).
    id_list: list[str] = [info.id for info in infos]
    tag_map: dict[str, list[str]] = defaultdict(list)
    if id_list:
        tag_rows = session.execute(
            select(AssetInfoTag.asset_info_id, Tag.name)
            .join(Tag, Tag.name == AssetInfoTag.tag_name)
            .where(AssetInfoTag.asset_info_id.in_(id_list))
            # Same tag ordering as fetch_asset_info_asset_and_tags.
            .order_by(Tag.name.asc())
        )
        for aid, tag_name in tag_rows.all():
            tag_map[aid].append(tag_name)

    return infos, tag_map, total
|
|
|
|
def fetch_asset_info_asset_and_tags(
    session: Session,
    asset_info_id: str,
    owner_id: str = "",
) -> tuple[AssetInfo, Asset, list[str]] | None:
    """Fetch one visible ``AssetInfo`` with its ``Asset`` and tag names.

    A single query joins the asset and outer-joins the tag tables, yielding
    one row per tag (or one row with a NULL tag when there are none).

    Returns ``None`` when no row matches ``asset_info_id`` under the
    visibility predicate for ``owner_id``; otherwise ``(info, asset, tags)``
    with ``tags`` de-duplicated and in ascending name order.
    """
    query = (
        select(AssetInfo, Asset, Tag.name)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .outerjoin(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id)
        .outerjoin(Tag, Tag.name == AssetInfoTag.tag_name)
        .where(
            AssetInfo.id == asset_info_id,
            visible_owner_clause(owner_id),
        )
        .options(noload(AssetInfo.tags))
        .order_by(Tag.name.asc())
    )

    result_rows = session.execute(query).all()
    if not result_rows:
        return None

    head = result_rows[0]
    info, asset = head[0], head[1]
    # dict.fromkeys keeps first-seen order while dropping duplicates; falsy
    # names (the NULL from the outer join) are skipped.
    tags: list[str] = list(dict.fromkeys(row[2] for row in result_rows if row[2]))
    return info, asset, tags
|
|
|
|
def list_tags_with_usage(
    session: Session,
    prefix: str | None = None,
    limit: int = 100,
    offset: int = 0,
    include_zero: bool = True,
    order: str = "count_desc",
    owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
    """List tags with how many visible asset infos use each of them.

    Args:
        session: SQLAlchemy session used for both queries.
        prefix: Optional tag-name prefix; stripped and lower-cased, then
            escaped with ``escape_like_prefix`` for a ``LIKE 'prefix%'`` match.
        limit: Page size.
        offset: Number of rows to skip.
        include_zero: When ``False``, tags with no visible usage are omitted.
        order: ``"name_asc"`` sorts by name; any other value sorts by usage
            count descending, then by name.
        owner_id: Caller identity; usages are counted only on asset infos
            matching ``visible_owner_clause(owner_id)``.

    Returns:
        ``(rows, total)`` where each row is ``(name, tag_type, count)`` and
        ``total`` is the number of tags matching the same filters without
        ``limit``/``offset``.
    """
    # Per-tag usage counts restricted to asset infos the caller may see.
    # Tags without any visible usage have no row here.
    counts_sq = (
        select(
            AssetInfoTag.tag_name.label("tag_name"),
            func.count(AssetInfoTag.asset_info_id).label("cnt"),
        )
        .select_from(AssetInfoTag)
        .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
        .where(visible_owner_clause(owner_id))
        .group_by(AssetInfoTag.tag_name)
        .subquery()
    )
    usage = func.coalesce(counts_sq.c.cnt, 0)

    q = (
        select(Tag.name, Tag.tag_type, usage.label("count"))
        .select_from(Tag)
        .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
    )
    total_q = select(func.count()).select_from(Tag)

    if prefix:
        escaped, esc = escape_like_prefix(prefix.strip().lower())
        prefix_clause = Tag.name.like(escaped + "%", escape=esc)
        q = q.where(prefix_clause)
        total_q = total_q.where(prefix_clause)

    if not include_zero:
        q = q.where(usage > 0)
        # The total must use the same visibility-scoped counts as the rows;
        # counting tags used by *any* owner would overstate it.
        total_q = total_q.where(Tag.name.in_(select(counts_sq.c.tag_name)))

    if order == "name_asc":
        q = q.order_by(Tag.name.asc())
    else:
        q = q.order_by(usage.desc(), Tag.name.asc())

    rows = session.execute(q.limit(limit).offset(offset)).all()
    total = session.execute(total_q).scalar_one()

    rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
    return rows_norm, int(total or 0)
|