mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Switch to asyncio.to_thread for db calls in job.py
This commit is contained in:
parent
4ccaaa6f37
commit
c785130223
@ -149,38 +149,38 @@ class DownloadJob:
|
||||
|
||||
async def run(self) -> str:
|
||||
"""Run to a terminal/paused state; returns the final status string."""
|
||||
self._set_status(DownloadStatus.ACTIVE, error=None)
|
||||
await self._set_status(DownloadStatus.ACTIVE, error=None)
|
||||
try:
|
||||
pr = await self._probe_and_plan()
|
||||
await self._transfer(pr)
|
||||
await self._finalize()
|
||||
self._set_status(DownloadStatus.COMPLETED)
|
||||
await self._set_status(DownloadStatus.COMPLETED)
|
||||
except Paused:
|
||||
await self._persist_progress(force=True)
|
||||
self._set_status(DownloadStatus.PAUSED)
|
||||
await self._set_status(DownloadStatus.PAUSED)
|
||||
except Cancelled:
|
||||
await self._close_writer()
|
||||
self._remove_temp()
|
||||
self._set_status(DownloadStatus.CANCELLED)
|
||||
await self._set_status(DownloadStatus.CANCELLED)
|
||||
except RemoteChanged:
|
||||
await self._reset_for_restart()
|
||||
self._set_status(
|
||||
await self._set_status(
|
||||
DownloadStatus.QUEUED, error="remote file changed; restarting"
|
||||
)
|
||||
except RetryableError as e:
|
||||
await self._persist_progress(force=True)
|
||||
self._set_status(DownloadStatus.QUEUED, error=str(e))
|
||||
await self._set_status(DownloadStatus.QUEUED, error=str(e))
|
||||
except FatalError as e:
|
||||
await self._close_writer()
|
||||
self._remove_temp()
|
||||
self._set_status(DownloadStatus.FAILED, error=str(e))
|
||||
await self._set_status(DownloadStatus.FAILED, error=str(e))
|
||||
except Exception as e: # unexpected -> treat as retryable
|
||||
logging.warning(
|
||||
"[model_downloader] %s unexpected error: %s",
|
||||
self.spec.model_id, e, exc_info=True,
|
||||
)
|
||||
await self._persist_progress(force=True)
|
||||
self._set_status(DownloadStatus.QUEUED, error=f"{type(e).__name__}: {e}")
|
||||
await self._set_status(DownloadStatus.QUEUED, error=f"{type(e).__name__}: {e}")
|
||||
finally:
|
||||
await self._close_writer()
|
||||
return self.state.status
|
||||
@ -208,7 +208,8 @@ class DownloadJob:
|
||||
|
||||
self._etag = pr.etag or self._etag
|
||||
self.state.total_bytes = pr.total_bytes
|
||||
queries.update_download(
|
||||
await asyncio.to_thread(
|
||||
queries.update_download,
|
||||
self.spec.download_id,
|
||||
final_url=pr.final_url,
|
||||
total_bytes=pr.total_bytes,
|
||||
@ -220,7 +221,7 @@ class DownloadJob:
|
||||
seg_count = effective_segment_count(
|
||||
pr.total_bytes, pr.accept_ranges, max(1, args.download_segments)
|
||||
)
|
||||
existing = queries.list_segments(self.spec.download_id)
|
||||
existing = await asyncio.to_thread(queries.list_segments, self.spec.download_id)
|
||||
if (
|
||||
seg_count > 1
|
||||
and existing
|
||||
@ -234,7 +235,8 @@ class DownloadJob:
|
||||
]
|
||||
elif seg_count > 1 and pr.total_bytes is not None:
|
||||
plans = plan_segments(pr.total_bytes, seg_count)
|
||||
queries.replace_segments(
|
||||
await asyncio.to_thread(
|
||||
queries.replace_segments,
|
||||
self.spec.download_id,
|
||||
[
|
||||
{"idx": p.idx, "start_offset": p.start, "end_offset": p.end, "bytes_done": 0}
|
||||
@ -244,7 +246,7 @@ class DownloadJob:
|
||||
self.state.segments = [SegmentRuntime(p.idx, p.start, p.end, 0) for p in plans]
|
||||
else:
|
||||
# Single-stream: one logical segment; bytes_done tracked on the row.
|
||||
row = queries.get_download(self.spec.download_id)
|
||||
row = await asyncio.to_thread(queries.get_download, self.spec.download_id)
|
||||
resume_from = row.bytes_done if row else 0
|
||||
end = (pr.total_bytes - 1) if pr.total_bytes else -1
|
||||
# ``row.bytes_done`` may be the SUM of per-segment offsets from a
|
||||
@ -263,9 +265,13 @@ class DownloadJob:
|
||||
)
|
||||
resume_from = 0
|
||||
self._remove_temp()
|
||||
if queries.list_segments(self.spec.download_id):
|
||||
queries.replace_segments(self.spec.download_id, [])
|
||||
queries.update_download(self.spec.download_id, bytes_done=0)
|
||||
if await asyncio.to_thread(queries.list_segments, self.spec.download_id):
|
||||
await asyncio.to_thread(
|
||||
queries.replace_segments, self.spec.download_id, []
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
queries.update_download, self.spec.download_id, bytes_done=0
|
||||
)
|
||||
self.state.segments = [SegmentRuntime(0, 0, end, resume_from)]
|
||||
self._recompute_bytes_done()
|
||||
return pr
|
||||
@ -419,7 +425,7 @@ class DownloadJob:
|
||||
async def _finalize(self) -> None:
|
||||
self._check_control()
|
||||
await self._close_writer()
|
||||
self._set_status(DownloadStatus.VERIFYING)
|
||||
await self._set_status(DownloadStatus.VERIFYING)
|
||||
|
||||
total = self.state.total_bytes
|
||||
segmented = len(self.state.segments) > 1
|
||||
@ -486,20 +492,28 @@ class DownloadJob:
|
||||
self.state._last_time = now
|
||||
|
||||
async def _persist_progress(self, force: bool = False) -> None:
|
||||
# Both the DB write and the websocket notify are gated by the same
|
||||
# throttle: persisting hits SQLite, and notifying broadcasts to every
|
||||
# client, so doing either per-chunk (small --download-chunk-size or
|
||||
# many concurrent segments) would overwhelm both. Skip entirely inside
|
||||
# the window; the next persist (or a forced one) ships the latest bytes.
|
||||
now = time.monotonic()
|
||||
if not force and now - self._last_persist < _PERSIST_INTERVAL:
|
||||
if self._notify:
|
||||
self._notify(self.spec.download_id)
|
||||
return
|
||||
self._last_persist = now
|
||||
# SQLite is blocking; run it off the event loop per the queries module
|
||||
# contract so progress persists don't stall the web server.
|
||||
await asyncio.to_thread(self._write_progress)
|
||||
if self._notify:
|
||||
self._notify(self.spec.download_id)
|
||||
|
||||
def _write_progress(self) -> None:
|
||||
queries.update_download(self.spec.download_id, bytes_done=self.state.bytes_done)
|
||||
for seg in self.state.segments:
|
||||
if seg.end >= seg.start: # skip unknown-size sentinel
|
||||
queries.update_segment_progress(
|
||||
self.spec.download_id, seg.idx, seg.bytes_done
|
||||
)
|
||||
if self._notify:
|
||||
self._notify(self.spec.download_id)
|
||||
|
||||
async def _reset_for_restart(self) -> None:
|
||||
await self._close_writer()
|
||||
@ -507,9 +521,13 @@ class DownloadJob:
|
||||
for seg in self.state.segments:
|
||||
seg.bytes_done = 0
|
||||
self.state.bytes_done = 0
|
||||
queries.update_download(self.spec.download_id, bytes_done=0)
|
||||
if queries.list_segments(self.spec.download_id):
|
||||
queries.replace_segments(self.spec.download_id, [])
|
||||
await asyncio.to_thread(
|
||||
queries.update_download, self.spec.download_id, bytes_done=0
|
||||
)
|
||||
if await asyncio.to_thread(queries.list_segments, self.spec.download_id):
|
||||
await asyncio.to_thread(
|
||||
queries.replace_segments, self.spec.download_id, []
|
||||
)
|
||||
|
||||
async def _close_writer(self) -> None:
|
||||
if self._writer is not None:
|
||||
@ -543,7 +561,7 @@ class DownloadJob:
|
||||
"[model_downloader] could not remove %s: %s", self.spec.temp_path, e
|
||||
)
|
||||
|
||||
def _set_status(self, status: str, error: Optional[str] = None) -> None:
|
||||
async def _set_status(self, status: str, error: Optional[str] = None) -> None:
|
||||
# ``error`` is authoritative: passing None clears any prior failure
|
||||
# text so transitions out of a failure state (retry/success) don't
|
||||
# leave stale messages on RuntimeState or in the persisted row.
|
||||
@ -553,6 +571,6 @@ class DownloadJob:
|
||||
if status == DownloadStatus.QUEUED:
|
||||
fields["attempts"] = self.spec.attempts + 1
|
||||
self.spec.attempts += 1
|
||||
queries.update_download(self.spec.download_id, **fields)
|
||||
await asyncio.to_thread(queries.update_download, self.spec.download_id, **fields)
|
||||
if self._notify:
|
||||
self._notify(self.spec.download_id)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user