mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-19 02:53:05 +08:00
optimization: DB Queries (Tags)
This commit is contained in:
parent
7becb84341
commit
025fc49b4e
@ -1,6 +1,6 @@
|
|||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
||||||
from sqlalchemy import delete, select
|
import sqlalchemy as sa
|
||||||
from sqlalchemy.dialects import postgresql as d_pg
|
from sqlalchemy.dialects import postgresql as d_pg
|
||||||
from sqlalchemy.dialects import sqlite as d_sqlite
|
from sqlalchemy.dialects import sqlite as d_sqlite
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
@ -10,34 +10,27 @@ from ..models import AssetInfo, AssetInfoTag, Tag
|
|||||||
from ..timeutil import utcnow
|
from ..timeutil import utcnow
|
||||||
|
|
||||||
|
|
||||||
async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
|
async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> None:
|
||||||
wanted = normalize_tags(list(names))
|
wanted = normalize_tags(list(names))
|
||||||
if not wanted:
|
if not wanted:
|
||||||
return []
|
return
|
||||||
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
|
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||||
existing_names = {t.name for t in existing}
|
dialect = session.bind.dialect.name
|
||||||
missing = [n for n in wanted if n not in existing_names]
|
if dialect == "sqlite":
|
||||||
if missing:
|
ins = (
|
||||||
dialect = session.bind.dialect.name
|
d_sqlite.insert(Tag)
|
||||||
rows = [{"name": n, "tag_type": tag_type} for n in missing]
|
.values(rows)
|
||||||
if dialect == "sqlite":
|
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||||
ins = (
|
)
|
||||||
d_sqlite.insert(Tag)
|
elif dialect == "postgresql":
|
||||||
.values(rows)
|
ins = (
|
||||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
d_pg.insert(Tag)
|
||||||
)
|
.values(rows)
|
||||||
elif dialect == "postgresql":
|
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||||
ins = (
|
)
|
||||||
d_pg.insert(Tag)
|
else:
|
||||||
.values(rows)
|
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
await session.execute(ins)
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
|
||||||
await session.execute(ins)
|
|
||||||
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
|
|
||||||
by_name = {t.name: t for t in existing}
|
|
||||||
return [by_name[n] for n in wanted if n in by_name]
|
|
||||||
|
|
||||||
|
|
||||||
async def add_missing_tag_for_asset_id(
|
async def add_missing_tag_for_asset_id(
|
||||||
@ -45,53 +38,53 @@ async def add_missing_tag_for_asset_id(
|
|||||||
*,
|
*,
|
||||||
asset_id: str,
|
asset_id: str,
|
||||||
origin: str = "automatic",
|
origin: str = "automatic",
|
||||||
) -> int:
|
) -> None:
|
||||||
"""Ensure every AssetInfo for asset_id has 'missing' tag."""
|
select_rows = (
|
||||||
ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all()
|
sa.select(
|
||||||
if not ids:
|
AssetInfo.id.label("asset_info_id"),
|
||||||
return 0
|
sa.literal("missing").label("tag_name"),
|
||||||
|
sa.literal(origin).label("origin"),
|
||||||
existing = {
|
sa.literal(utcnow()).label("added_at"),
|
||||||
asset_info_id
|
)
|
||||||
for (asset_info_id,) in (
|
.where(AssetInfo.asset_id == asset_id)
|
||||||
await session.execute(
|
.where(
|
||||||
select(AssetInfoTag.asset_info_id).where(
|
sa.not_(
|
||||||
AssetInfoTag.asset_info_id.in_(ids),
|
sa.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
|
||||||
AssetInfoTag.tag_name == "missing",
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
).all()
|
)
|
||||||
}
|
|
||||||
to_add = [i for i in ids if i not in existing]
|
|
||||||
if not to_add:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
now = utcnow()
|
|
||||||
session.add_all(
|
|
||||||
[
|
|
||||||
AssetInfoTag(asset_info_id=i, tag_name="missing", origin=origin, added_at=now)
|
|
||||||
for i in to_add
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
await session.flush()
|
dialect = session.bind.dialect.name
|
||||||
return len(to_add)
|
if dialect == "sqlite":
|
||||||
|
ins = (
|
||||||
|
d_sqlite.insert(AssetInfoTag)
|
||||||
|
.from_select(
|
||||||
|
["asset_info_id", "tag_name", "origin", "added_at"],
|
||||||
|
select_rows,
|
||||||
|
)
|
||||||
|
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||||
|
)
|
||||||
|
elif dialect == "postgresql":
|
||||||
|
ins = (
|
||||||
|
d_pg.insert(AssetInfoTag)
|
||||||
|
.from_select(
|
||||||
|
["asset_info_id", "tag_name", "origin", "added_at"],
|
||||||
|
select_rows,
|
||||||
|
)
|
||||||
|
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||||
|
await session.execute(ins)
|
||||||
|
|
||||||
|
|
||||||
async def remove_missing_tag_for_asset_id(
|
async def remove_missing_tag_for_asset_id(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
*,
|
*,
|
||||||
asset_id: str,
|
asset_id: str,
|
||||||
) -> int:
|
) -> None:
|
||||||
"""Remove the 'missing' tag from all AssetInfos for asset_id."""
|
await session.execute(
|
||||||
ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all()
|
sa.delete(AssetInfoTag).where(
|
||||||
if not ids:
|
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||||
return 0
|
|
||||||
|
|
||||||
res = await session.execute(
|
|
||||||
delete(AssetInfoTag).where(
|
|
||||||
AssetInfoTag.asset_info_id.in_(ids),
|
|
||||||
AssetInfoTag.tag_name == "missing",
|
AssetInfoTag.tag_name == "missing",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await session.flush()
|
|
||||||
return int(res.rowcount or 0)
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user