optimization: DB Queries (Tags)

This commit is contained in:
bigcat88 2025-09-15 10:26:13 +03:00
parent 7becb84341
commit 025fc49b4e
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721

View File

@ -1,6 +1,6 @@
from typing import Iterable from typing import Iterable
from sqlalchemy import delete, select import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as d_pg from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@ -10,34 +10,27 @@ from ..models import AssetInfo, AssetInfoTag, Tag
from ..timeutil import utcnow from ..timeutil import utcnow
async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]: async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names)) wanted = normalize_tags(list(names))
if not wanted: if not wanted:
return [] return
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all() rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
existing_names = {t.name for t in existing} dialect = session.bind.dialect.name
missing = [n for n in wanted if n not in existing_names] if dialect == "sqlite":
if missing: ins = (
dialect = session.bind.dialect.name d_sqlite.insert(Tag)
rows = [{"name": n, "tag_type": tag_type} for n in missing] .values(rows)
if dialect == "sqlite": .on_conflict_do_nothing(index_elements=[Tag.name])
ins = ( )
d_sqlite.insert(Tag) elif dialect == "postgresql":
.values(rows) ins = (
.on_conflict_do_nothing(index_elements=[Tag.name]) d_pg.insert(Tag)
) .values(rows)
elif dialect == "postgresql": .on_conflict_do_nothing(index_elements=[Tag.name])
ins = ( )
d_pg.insert(Tag) else:
.values(rows) raise NotImplementedError(f"Unsupported database dialect: {dialect}")
.on_conflict_do_nothing(index_elements=[Tag.name]) await session.execute(ins)
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
await session.execute(ins)
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
by_name = {t.name: t for t in existing}
return [by_name[n] for n in wanted if n in by_name]
async def add_missing_tag_for_asset_id( async def add_missing_tag_for_asset_id(
@ -45,53 +38,53 @@ async def add_missing_tag_for_asset_id(
*, *,
asset_id: str, asset_id: str,
origin: str = "automatic", origin: str = "automatic",
) -> int: ) -> None:
"""Ensure every AssetInfo for asset_id has 'missing' tag.""" select_rows = (
ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all() sa.select(
if not ids: AssetInfo.id.label("asset_info_id"),
return 0 sa.literal("missing").label("tag_name"),
sa.literal(origin).label("origin"),
existing = { sa.literal(utcnow()).label("added_at"),
asset_info_id )
for (asset_info_id,) in ( .where(AssetInfo.asset_id == asset_id)
await session.execute( .where(
select(AssetInfoTag.asset_info_id).where( sa.not_(
AssetInfoTag.asset_info_id.in_(ids), sa.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
AssetInfoTag.tag_name == "missing",
)
) )
).all() )
}
to_add = [i for i in ids if i not in existing]
if not to_add:
return 0
now = utcnow()
session.add_all(
[
AssetInfoTag(asset_info_id=i, tag_name="missing", origin=origin, added_at=now)
for i in to_add
]
) )
await session.flush() dialect = session.bind.dialect.name
return len(to_add) if dialect == "sqlite":
ins = (
d_sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
elif dialect == "postgresql":
ins = (
d_pg.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
await session.execute(ins)
async def remove_missing_tag_for_asset_id( async def remove_missing_tag_for_asset_id(
session: AsyncSession, session: AsyncSession,
*, *,
asset_id: str, asset_id: str,
) -> int: ) -> None:
"""Remove the 'missing' tag from all AssetInfos for asset_id.""" await session.execute(
ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all() sa.delete(AssetInfoTag).where(
if not ids: AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
return 0
res = await session.execute(
delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(ids),
AssetInfoTag.tag_name == "missing", AssetInfoTag.tag_name == "missing",
) )
) )
await session.flush()
return int(res.rowcount or 0)