From 025fc49b4e667f53ece6978b61d9d8bf093983ee Mon Sep 17 00:00:00 2001
From: bigcat88
Date: Mon, 15 Sep 2025 10:26:13 +0300
Subject: [PATCH] optimization: DB Queries (Tags)

---
Review notes (not part of the commit message):

* ensure_tags_exist no longer SELECTs before and after the insert; it builds
  one row per name in `wanted` (de-duplicated via dict.fromkeys) and issues a
  single INSERT with on_conflict_do_nothing on Tag.name. It now returns None
  instead of list[Tag].
* add_missing_tag_for_asset_id replaces the select / filter-in-Python /
  session.add_all sequence with one INSERT ... FROM SELECT over AssetInfo rows
  matching asset_id that have no existing 'missing' AssetInfoTag, with
  on_conflict_do_nothing on (asset_info_id, tag_name). It now returns None
  instead of the number of tags added.
* remove_missing_tag_for_asset_id passes the AssetInfo.id lookup to in_() as a
  subquery instead of fetching the ids first. It now returns None instead of
  the deleted row count.
* Both insert paths read session.bind.dialect.name and raise
  NotImplementedError for anything other than "sqlite" or "postgresql".
* The one-line docstrings on add_missing_tag_for_asset_id and
  remove_missing_tag_for_asset_id, and the explicit session.flush() calls,
  are removed by this patch.
* NOTE(review): all three return types change and only tags.py is touched
  here; callers that use the old return values are not visible in this patch
  and need to be checked.

 app/database/helpers/tags.py | 125 +++++++++++++++++------------------
 1 file changed, 59 insertions(+), 66 deletions(-)

diff --git a/app/database/helpers/tags.py b/app/database/helpers/tags.py
index b3e006c0e..058869eca 100644
--- a/app/database/helpers/tags.py
+++ b/app/database/helpers/tags.py
@@ -1,6 +1,6 @@
 from typing import Iterable
 
-from sqlalchemy import delete, select
+import sqlalchemy as sa
 from sqlalchemy.dialects import postgresql as d_pg
 from sqlalchemy.dialects import sqlite as d_sqlite
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -10,34 +10,27 @@ from ..models import AssetInfo, AssetInfoTag, Tag
 from ..timeutil import utcnow
 
 
-async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
+async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> None:
     wanted = normalize_tags(list(names))
     if not wanted:
-        return []
-    existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
-    existing_names = {t.name for t in existing}
-    missing = [n for n in wanted if n not in existing_names]
-    if missing:
-        dialect = session.bind.dialect.name
-        rows = [{"name": n, "tag_type": tag_type} for n in missing]
-        if dialect == "sqlite":
-            ins = (
-                d_sqlite.insert(Tag)
-                .values(rows)
-                .on_conflict_do_nothing(index_elements=[Tag.name])
-            )
-        elif dialect == "postgresql":
-            ins = (
-                d_pg.insert(Tag)
-                .values(rows)
-                .on_conflict_do_nothing(index_elements=[Tag.name])
-            )
-        else:
-            raise NotImplementedError(f"Unsupported database dialect: {dialect}")
-        await session.execute(ins)
-        existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
-    by_name = {t.name: t for t in existing}
-    return [by_name[n] for n in wanted if n in by_name]
+        return
+    rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
+    dialect = session.bind.dialect.name
+    if dialect == "sqlite":
+        ins = (
+            d_sqlite.insert(Tag)
+            .values(rows)
+            .on_conflict_do_nothing(index_elements=[Tag.name])
+        )
+    elif dialect == "postgresql":
+        ins = (
+            d_pg.insert(Tag)
+            .values(rows)
+            .on_conflict_do_nothing(index_elements=[Tag.name])
+        )
+    else:
+        raise NotImplementedError(f"Unsupported database dialect: {dialect}")
+    await session.execute(ins)
 
 
 async def add_missing_tag_for_asset_id(
@@ -45,53 +38,53 @@ async def add_missing_tag_for_asset_id(
     *,
     asset_id: str,
     origin: str = "automatic",
-) -> int:
-    """Ensure every AssetInfo for asset_id has 'missing' tag."""
-    ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all()
-    if not ids:
-        return 0
-
-    existing = {
-        asset_info_id
-        for (asset_info_id,) in (
-            await session.execute(
-                select(AssetInfoTag.asset_info_id).where(
-                    AssetInfoTag.asset_info_id.in_(ids),
-                    AssetInfoTag.tag_name == "missing",
-                )
+) -> None:
+    select_rows = (
+        sa.select(
+            AssetInfo.id.label("asset_info_id"),
+            sa.literal("missing").label("tag_name"),
+            sa.literal(origin).label("origin"),
+            sa.literal(utcnow()).label("added_at"),
+        )
+        .where(AssetInfo.asset_id == asset_id)
+        .where(
+            sa.not_(
+                sa.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
             )
-        ).all()
-    }
-    to_add = [i for i in ids if i not in existing]
-    if not to_add:
-        return 0
-
-    now = utcnow()
-    session.add_all(
-        [
-            AssetInfoTag(asset_info_id=i, tag_name="missing", origin=origin, added_at=now)
-            for i in to_add
-        ]
+        )
     )
-    await session.flush()
-    return len(to_add)
+    dialect = session.bind.dialect.name
+    if dialect == "sqlite":
+        ins = (
+            d_sqlite.insert(AssetInfoTag)
+            .from_select(
+                ["asset_info_id", "tag_name", "origin", "added_at"],
+                select_rows,
+            )
+            .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
+        )
+    elif dialect == "postgresql":
+        ins = (
+            d_pg.insert(AssetInfoTag)
+            .from_select(
+                ["asset_info_id", "tag_name", "origin", "added_at"],
+                select_rows,
+            )
+            .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
+        )
+    else:
+        raise NotImplementedError(f"Unsupported database dialect: {dialect}")
+    await session.execute(ins)
 
 
 async def remove_missing_tag_for_asset_id(
     session: AsyncSession,
     *,
     asset_id: str,
-) -> int:
-    """Remove the 'missing' tag from all AssetInfos for asset_id."""
-    ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all()
-    if not ids:
-        return 0
-
-    res = await session.execute(
-        delete(AssetInfoTag).where(
-            AssetInfoTag.asset_info_id.in_(ids),
+) -> None:
+    await session.execute(
+        sa.delete(AssetInfoTag).where(
+            AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
             AssetInfoTag.tag_name == "missing",
         )
     )
-    await session.flush()
-    return int(res.rowcount or 0)