diff --git a/app/database/services/content.py b/app/database/services/content.py index 546cc7bd1..ead2e2389 100644 --- a/app/database/services/content.py +++ b/app/database/services/content.py @@ -513,14 +513,20 @@ async def ingest_fs_asset( "created_at": now, } if dialect == "sqlite": - ins = ( + res = await session.execute( d_sqlite.insert(Asset) .values(**vals) .on_conflict_do_nothing(index_elements=[Asset.hash]) - .returning(Asset.id) ) + if int(res.rowcount or 0) > 0: + out["asset_created"] = True + asset = ( + await session.execute( + select(Asset).where(Asset.hash == asset_hash).limit(1) + ) + ).scalars().first() elif dialect == "postgresql": - ins = ( + res = await session.execute( d_pg.insert(Asset) .values(**vals) .on_conflict_do_nothing( @@ -529,24 +535,20 @@ async def ingest_fs_asset( ) .returning(Asset.id) ) + inserted_id = res.scalar_one_or_none() + if inserted_id: + out["asset_created"] = True + asset = await session.get(Asset, inserted_id) + else: + asset = ( + await session.execute( + select(Asset).where(Asset.hash == asset_hash).limit(1) + ) + ).scalars().first() else: raise NotImplementedError(f"Unsupported database dialect: {dialect}") - res = await session.execute(ins) - inserted_id = res.scalar_one_or_none() - asset = ( - await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) - ).scalars().first() if not asset: raise RuntimeError("Asset row not found after upsert.") - if inserted_id: - out["asset_created"] = True - asset = await session.get(Asset, inserted_id) - else: - asset = ( - await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) - ).scalars().first() - if not asset: - raise RuntimeError("Asset row not found after upsert.") else: changed = False if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: diff --git a/app/database/services/info.py b/app/database/services/info.py index 5c7e3c92f..d2fd1f503 100644 --- a/app/database/services/info.py +++ b/app/database/services/info.py @@ -377,17 +377,20 @@ async def touch_asset_info_by_id( stmt = stmt.where( sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) ) - stmt = stmt.values(last_access_time=ts).returning(AssetInfo.id) - return (await session.execute(stmt)).scalar_one_or_none() is not None + stmt = stmt.values(last_access_time=ts) + if session.bind.dialect.name == "postgresql": + return (await session.execute(stmt.returning(AssetInfo.id))).scalar_one_or_none() is not None + return int((await session.execute(stmt)).rowcount or 0) > 0 async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool: - return ( - await session.execute(delete(AssetInfo).where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ).returning(AssetInfo.id)) - ).scalar_one_or_none() is not None + stmt = sa.delete(AssetInfo).where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + if session.bind.dialect.name == "postgresql": + return (await session.execute(stmt.returning(AssetInfo.id))).scalar_one_or_none() is not None + return int((await session.execute(stmt)).rowcount or 0) > 0 async def add_tags_to_asset_info(