From c78513022375cc24a9f43f4009df17ac3d59eb84 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Tue, 30 Jun 2026 10:50:54 +0200 Subject: [PATCH] Switch to asyncio.to_thread for db calls in job.py --- app/model_downloader/engine/job.py | 68 +++++++++++++++++++----------- 1 file changed, 43 insertions(+), 25 deletions(-) diff --git a/app/model_downloader/engine/job.py b/app/model_downloader/engine/job.py index 499e41c81..0b1519ce9 100644 --- a/app/model_downloader/engine/job.py +++ b/app/model_downloader/engine/job.py @@ -149,38 +149,38 @@ class DownloadJob: async def run(self) -> str: """Run to a terminal/paused state; returns the final status string.""" - self._set_status(DownloadStatus.ACTIVE, error=None) + await self._set_status(DownloadStatus.ACTIVE, error=None) try: pr = await self._probe_and_plan() await self._transfer(pr) await self._finalize() - self._set_status(DownloadStatus.COMPLETED) + await self._set_status(DownloadStatus.COMPLETED) except Paused: await self._persist_progress(force=True) - self._set_status(DownloadStatus.PAUSED) + await self._set_status(DownloadStatus.PAUSED) except Cancelled: await self._close_writer() self._remove_temp() - self._set_status(DownloadStatus.CANCELLED) + await self._set_status(DownloadStatus.CANCELLED) except RemoteChanged: await self._reset_for_restart() - self._set_status( + await self._set_status( DownloadStatus.QUEUED, error="remote file changed; restarting" ) except RetryableError as e: await self._persist_progress(force=True) - self._set_status(DownloadStatus.QUEUED, error=str(e)) + await self._set_status(DownloadStatus.QUEUED, error=str(e)) except FatalError as e: await self._close_writer() self._remove_temp() - self._set_status(DownloadStatus.FAILED, error=str(e)) + await self._set_status(DownloadStatus.FAILED, error=str(e)) except Exception as e: # unexpected -> treat as retryable logging.warning( "[model_downloader] %s unexpected error: %s", self.spec.model_id, e, exc_info=True, ) await self._persist_progress(force=True) - self._set_status(DownloadStatus.QUEUED, error=f"{type(e).__name__}: {e}") + await self._set_status(DownloadStatus.QUEUED, error=f"{type(e).__name__}: {e}") finally: await self._close_writer() return self.state.status @@ -208,7 +208,8 @@ class DownloadJob: self._etag = pr.etag or self._etag self.state.total_bytes = pr.total_bytes - queries.update_download( + await asyncio.to_thread( + queries.update_download, self.spec.download_id, final_url=pr.final_url, total_bytes=pr.total_bytes, @@ -220,7 +221,7 @@ class DownloadJob: seg_count = effective_segment_count( pr.total_bytes, pr.accept_ranges, max(1, args.download_segments) ) - existing = queries.list_segments(self.spec.download_id) + existing = await asyncio.to_thread(queries.list_segments, self.spec.download_id) if ( seg_count > 1 and existing @@ -234,7 +235,8 @@ class DownloadJob: ] elif seg_count > 1 and pr.total_bytes is not None: plans = plan_segments(pr.total_bytes, seg_count) - queries.replace_segments( + await asyncio.to_thread( + queries.replace_segments, self.spec.download_id, [ {"idx": p.idx, "start_offset": p.start, "end_offset": p.end, "bytes_done": 0} @@ -244,7 +246,7 @@ class DownloadJob: self.state.segments = [SegmentRuntime(p.idx, p.start, p.end, 0) for p in plans] else: # Single-stream: one logical segment; bytes_done tracked on the row. - row = queries.get_download(self.spec.download_id) + row = await asyncio.to_thread(queries.get_download, self.spec.download_id) resume_from = row.bytes_done if row else 0 end = (pr.total_bytes - 1) if pr.total_bytes else -1 # ``row.bytes_done`` may be the SUM of per-segment offsets from a @@ -263,9 +265,13 @@ class DownloadJob: ) resume_from = 0 self._remove_temp() - if queries.list_segments(self.spec.download_id): - queries.replace_segments(self.spec.download_id, []) - queries.update_download(self.spec.download_id, bytes_done=0) + if await asyncio.to_thread(queries.list_segments, self.spec.download_id): + await asyncio.to_thread( + queries.replace_segments, self.spec.download_id, [] + ) + await asyncio.to_thread( + queries.update_download, self.spec.download_id, bytes_done=0 + ) self.state.segments = [SegmentRuntime(0, 0, end, resume_from)] self._recompute_bytes_done() return pr @@ -419,7 +425,7 @@ class DownloadJob: async def _finalize(self) -> None: self._check_control() await self._close_writer() - self._set_status(DownloadStatus.VERIFYING) + await self._set_status(DownloadStatus.VERIFYING) total = self.state.total_bytes segmented = len(self.state.segments) > 1 @@ -486,20 +492,28 @@ class DownloadJob: self.state._last_time = now async def _persist_progress(self, force: bool = False) -> None: + # Both the DB write and the websocket notify are gated by the same + # throttle: persisting hits SQLite, and notifying broadcasts to every + # client, so doing either per-chunk (small --download-chunk-size or + # many concurrent segments) would overwhelm both. Skip entirely inside + # the window; the next persist (or a forced one) ships the latest bytes. now = time.monotonic() if not force and now - self._last_persist < _PERSIST_INTERVAL: - if self._notify: - self._notify(self.spec.download_id) return self._last_persist = now + # SQLite is blocking; run it off the event loop per the queries module + # contract so progress persists don't stall the web server. + await asyncio.to_thread(self._write_progress) + if self._notify: + self._notify(self.spec.download_id) + + def _write_progress(self) -> None: queries.update_download(self.spec.download_id, bytes_done=self.state.bytes_done) for seg in self.state.segments: if seg.end >= seg.start: # skip unknown-size sentinel queries.update_segment_progress( self.spec.download_id, seg.idx, seg.bytes_done ) - if self._notify: - self._notify(self.spec.download_id) async def _reset_for_restart(self) -> None: await self._close_writer() @@ -507,9 +521,13 @@ class DownloadJob: for seg in self.state.segments: seg.bytes_done = 0 self.state.bytes_done = 0 - queries.update_download(self.spec.download_id, bytes_done=0) - if queries.list_segments(self.spec.download_id): - queries.replace_segments(self.spec.download_id, []) + await asyncio.to_thread( + queries.update_download, self.spec.download_id, bytes_done=0 + ) + if await asyncio.to_thread(queries.list_segments, self.spec.download_id): + await asyncio.to_thread( + queries.replace_segments, self.spec.download_id, [] + ) async def _close_writer(self) -> None: if self._writer is not None: @@ -543,7 +561,7 @@ class DownloadJob: "[model_downloader] could not remove %s: %s", self.spec.temp_path, e ) - def _set_status(self, status: str, error: Optional[str] = None) -> None: + async def _set_status(self, status: str, error: Optional[str] = None) -> None: # ``error`` is authoritative: passing None clears any prior failure # text so transitions out of a failure state (retry/success) don't # leave stale messages on RuntimeState or in the persisted row. @@ -553,6 +571,6 @@ class DownloadJob: if status == DownloadStatus.QUEUED: fields["attempts"] = self.spec.attempts + 1 self.spec.attempts += 1 - queries.update_download(self.spec.download_id, **fields) + await asyncio.to_thread(queries.update_download, self.spec.download_id, **fields) if self._notify: self._notify(self.spec.download_id)