Switch to asyncio.to_thread for db calls in job.py

This commit is contained in:
Talmaj Marinc 2026-06-30 10:50:54 +02:00
parent 4ccaaa6f37
commit c785130223

View File

@ -149,38 +149,38 @@ class DownloadJob:
async def run(self) -> str:
"""Run to a terminal/paused state; returns the final status string."""
self._set_status(DownloadStatus.ACTIVE, error=None)
await self._set_status(DownloadStatus.ACTIVE, error=None)
try:
pr = await self._probe_and_plan()
await self._transfer(pr)
await self._finalize()
self._set_status(DownloadStatus.COMPLETED)
await self._set_status(DownloadStatus.COMPLETED)
except Paused:
await self._persist_progress(force=True)
self._set_status(DownloadStatus.PAUSED)
await self._set_status(DownloadStatus.PAUSED)
except Cancelled:
await self._close_writer()
self._remove_temp()
self._set_status(DownloadStatus.CANCELLED)
await self._set_status(DownloadStatus.CANCELLED)
except RemoteChanged:
await self._reset_for_restart()
self._set_status(
await self._set_status(
DownloadStatus.QUEUED, error="remote file changed; restarting"
)
except RetryableError as e:
await self._persist_progress(force=True)
self._set_status(DownloadStatus.QUEUED, error=str(e))
await self._set_status(DownloadStatus.QUEUED, error=str(e))
except FatalError as e:
await self._close_writer()
self._remove_temp()
self._set_status(DownloadStatus.FAILED, error=str(e))
await self._set_status(DownloadStatus.FAILED, error=str(e))
except Exception as e: # unexpected -> treat as retryable
logging.warning(
"[model_downloader] %s unexpected error: %s",
self.spec.model_id, e, exc_info=True,
)
await self._persist_progress(force=True)
self._set_status(DownloadStatus.QUEUED, error=f"{type(e).__name__}: {e}")
await self._set_status(DownloadStatus.QUEUED, error=f"{type(e).__name__}: {e}")
finally:
await self._close_writer()
return self.state.status
@ -208,7 +208,8 @@ class DownloadJob:
self._etag = pr.etag or self._etag
self.state.total_bytes = pr.total_bytes
queries.update_download(
await asyncio.to_thread(
queries.update_download,
self.spec.download_id,
final_url=pr.final_url,
total_bytes=pr.total_bytes,
@ -220,7 +221,7 @@ class DownloadJob:
seg_count = effective_segment_count(
pr.total_bytes, pr.accept_ranges, max(1, args.download_segments)
)
existing = queries.list_segments(self.spec.download_id)
existing = await asyncio.to_thread(queries.list_segments, self.spec.download_id)
if (
seg_count > 1
and existing
@ -234,7 +235,8 @@ class DownloadJob:
]
elif seg_count > 1 and pr.total_bytes is not None:
plans = plan_segments(pr.total_bytes, seg_count)
queries.replace_segments(
await asyncio.to_thread(
queries.replace_segments,
self.spec.download_id,
[
{"idx": p.idx, "start_offset": p.start, "end_offset": p.end, "bytes_done": 0}
@ -244,7 +246,7 @@ class DownloadJob:
self.state.segments = [SegmentRuntime(p.idx, p.start, p.end, 0) for p in plans]
else:
# Single-stream: one logical segment; bytes_done tracked on the row.
row = queries.get_download(self.spec.download_id)
row = await asyncio.to_thread(queries.get_download, self.spec.download_id)
resume_from = row.bytes_done if row else 0
end = (pr.total_bytes - 1) if pr.total_bytes else -1
# ``row.bytes_done`` may be the SUM of per-segment offsets from a
@ -263,9 +265,13 @@ class DownloadJob:
)
resume_from = 0
self._remove_temp()
if queries.list_segments(self.spec.download_id):
queries.replace_segments(self.spec.download_id, [])
queries.update_download(self.spec.download_id, bytes_done=0)
if await asyncio.to_thread(queries.list_segments, self.spec.download_id):
await asyncio.to_thread(
queries.replace_segments, self.spec.download_id, []
)
await asyncio.to_thread(
queries.update_download, self.spec.download_id, bytes_done=0
)
self.state.segments = [SegmentRuntime(0, 0, end, resume_from)]
self._recompute_bytes_done()
return pr
@ -419,7 +425,7 @@ class DownloadJob:
async def _finalize(self) -> None:
self._check_control()
await self._close_writer()
self._set_status(DownloadStatus.VERIFYING)
await self._set_status(DownloadStatus.VERIFYING)
total = self.state.total_bytes
segmented = len(self.state.segments) > 1
@ -486,20 +492,28 @@ class DownloadJob:
self.state._last_time = now
async def _persist_progress(self, force: bool = False) -> None:
# Both the DB write and the websocket notify are gated by the same
# throttle: persisting hits SQLite, and notifying broadcasts to every
# client, so doing either per-chunk (small --download-chunk-size or
# many concurrent segments) would overwhelm both. Skip entirely inside
# the window; the next persist (or a forced one) ships the latest bytes.
now = time.monotonic()
if not force and now - self._last_persist < _PERSIST_INTERVAL:
if self._notify:
self._notify(self.spec.download_id)
return
self._last_persist = now
# SQLite is blocking; run it off the event loop per the queries module
# contract so progress persists don't stall the web server.
await asyncio.to_thread(self._write_progress)
if self._notify:
self._notify(self.spec.download_id)
def _write_progress(self) -> None:
queries.update_download(self.spec.download_id, bytes_done=self.state.bytes_done)
for seg in self.state.segments:
if seg.end >= seg.start: # skip unknown-size sentinel
queries.update_segment_progress(
self.spec.download_id, seg.idx, seg.bytes_done
)
if self._notify:
self._notify(self.spec.download_id)
async def _reset_for_restart(self) -> None:
await self._close_writer()
@ -507,9 +521,13 @@ class DownloadJob:
for seg in self.state.segments:
seg.bytes_done = 0
self.state.bytes_done = 0
queries.update_download(self.spec.download_id, bytes_done=0)
if queries.list_segments(self.spec.download_id):
queries.replace_segments(self.spec.download_id, [])
await asyncio.to_thread(
queries.update_download, self.spec.download_id, bytes_done=0
)
if await asyncio.to_thread(queries.list_segments, self.spec.download_id):
await asyncio.to_thread(
queries.replace_segments, self.spec.download_id, []
)
async def _close_writer(self) -> None:
if self._writer is not None:
@ -543,7 +561,7 @@ class DownloadJob:
"[model_downloader] could not remove %s: %s", self.spec.temp_path, e
)
def _set_status(self, status: str, error: Optional[str] = None) -> None:
async def _set_status(self, status: str, error: Optional[str] = None) -> None:
# ``error`` is authoritative: passing None clears any prior failure
# text so transitions out of a failure state (retry/success) don't
# leave stale messages on RuntimeState or in the persisted row.
@ -553,6 +571,6 @@ class DownloadJob:
if status == DownloadStatus.QUEUED:
fields["attempts"] = self.spec.attempts + 1
self.spec.attempts += 1
queries.update_download(self.spec.download_id, **fields)
await asyncio.to_thread(queries.update_download, self.spec.download_id, **fields)
if self._notify:
self._notify(self.spec.download_id)