diff --git a/app/model_downloader/database/queries.py b/app/model_downloader/database/queries.py index 6b5cbebe0..9216c388a 100644 --- a/app/model_downloader/database/queries.py +++ b/app/model_downloader/database/queries.py @@ -11,6 +11,7 @@ import time from typing import Optional from sqlalchemy import select +from sqlalchemy.exc import IntegrityError from app.database.db import create_session from app.model_downloader.constants import DownloadStatus @@ -199,30 +200,48 @@ def list_subdomain_credentials() -> list[HostCredential]: def upsert_credential(values: dict) -> HostCredential: - """Insert or update a credential keyed by ``host``.""" + """Insert or update a credential keyed by ``host``. + + Callers can target the same host concurrently (each runs in its own + short-lived session on a separate connection), so the read-then-write here + can race: two callers both see no existing row and both attempt an insert. + The ``host`` column is uniquely indexed, so the loser's insert raises + ``IntegrityError``. We recover by rolling back and retrying, at which point + the now-committed row is found and updated in place, letting concurrent + calls converge instead of failing or creating duplicates. + """ host = values["host"] now = int(time.time()) - with create_session() as session: - row = ( - session.execute( - select(HostCredential).where(HostCredential.host == host).limit(1) + last_error: IntegrityError | None = None + for _ in range(2): + with create_session() as session: + row = ( + session.execute( + select(HostCredential).where(HostCredential.host == host).limit(1) + ) + .scalars() + .first() ) - .scalars() - .first() - ) - if row is None: - row = HostCredential(**values) - row.created_at = now - row.updated_at = now - session.add(row) - else: - for key, value in values.items(): - setattr(row, key, value) - row.updated_at = now - session.commit() - session.refresh(row) - session.expunge(row) - return row + if row is None: + row = HostCredential(**values) + row.created_at = now + row.updated_at = now + session.add(row) + else: + for key, value in values.items(): + setattr(row, key, value) + row.updated_at = now + try: + session.commit() + except IntegrityError as exc: + session.rollback() + last_error = exc + continue + session.refresh(row) + session.expunge(row) + return row + assert last_error is not None + raise last_error def delete_credential(credential_id: str) -> bool: