diff --git a/app/database/db.py b/app/database/db.py index 2b09b8147..c0342e75f 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -4,10 +4,7 @@ import shutil from app.logger import log_startup_warning from utils.install_util import get_missing_requirements_message from filelock import FileLock, Timeout -# NOTE: import the module (not `from ... import args`) so we always read the -# live `args` object. Tests reload `comfy.cli_args`, which replaces the module -# global; a bound `args` reference would go stale and point at the default -# database URL instead of the one configured for the test. +# Import the module so tests that reload comfy.cli_args see the live object. import comfy.cli_args _DB_AVAILABLE = False diff --git a/app/model_downloader/database/queries.py b/app/model_downloader/database/queries.py index c71a234e1..e4b447841 100644 --- a/app/model_downloader/database/queries.py +++ b/app/model_downloader/database/queries.py @@ -50,6 +50,19 @@ def list_downloads() -> list[Download]: return rows +def has_live_download_for_model( + model_id: str, live_statuses: tuple[str, ...], exclude_id: Optional[str] = None +) -> bool: + with create_session() as session: + stmt = select(Download.id).where( + Download.model_id == model_id, + Download.status.in_(live_statuses), + ).limit(1) + if exclude_id is not None: + stmt = stmt.where(Download.id != exclude_id) + return session.execute(stmt).first() is not None + + def list_segments(download_id: str) -> list[DownloadSegment]: with create_session() as session: rows = list( diff --git a/app/model_downloader/manager.py b/app/model_downloader/manager.py index 60f597237..46d663235 100644 --- a/app/model_downloader/manager.py +++ b/app/model_downloader/manager.py @@ -191,12 +191,8 @@ class DownloadManager: async def _has_live_download( self, model_id: str, *, exclude_id: Optional[str] = None ) -> bool: - rows = await asyncio.to_thread(queries.list_downloads) - return any( - r.model_id == model_id - and r.id != exclude_id - and r.status in _LIVE_STATUSES - for r in rows + return await asyncio.to_thread( + queries.has_live_download_for_model, model_id, _LIVE_STATUSES, exclude_id ) # ----- control ----- @@ -249,8 +245,6 @@ class DownloadManager: if row is None: raise DownloadError("NOT_FOUND", "No such download.", status=404) if row.status in _LIVE_STATUSES: - import os - try: os.remove(row.temp_path) except OSError: