mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
161 lines
5.5 KiB
Python
161 lines
5.5 KiB
Python
"""Priority scheduler + lifecycle (PRD sections 4, 6, 12).
|
|
|
|
Owns the set of running jobs and admits queued downloads up to a global
|
|
concurrency limit (K), highest priority first, FIFO within a priority. Runs
|
|
entirely on the existing ComfyUI asyncio loop; blocking work (disk, hashing,
|
|
DB) is offloaded by the job/writer layers.
|
|
|
|
On startup it reconciles DB vs. disk: ``active``/``verifying`` rows left by a
|
|
previous run are reset to ``queued`` and resumed from persisted offsets, and
|
|
orphaned ``.part`` files with no live download row are swept.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import random
|
|
import time
|
|
from typing import Callable, Optional
|
|
|
|
from comfy.cli_args import args
|
|
from app.model_downloader.constants import DownloadStatus
|
|
from app.model_downloader.database import queries
|
|
from app.model_downloader.engine.job import DownloadJob, JobSpec
|
|
from app.model_downloader.security import paths
|
|
|
|
# Backoff for retryable failures (PRD section 12).
|
|
_BACKOFF_BASE = 2.0
|
|
_BACKOFF_CAP = 300.0
|
|
_MAX_ATTEMPTS = 6
|
|
|
|
|
|
class Scheduler:
|
|
def __init__(self) -> None:
|
|
self._jobs: dict[str, DownloadJob] = {}
|
|
self._tasks: dict[str, asyncio.Task] = {}
|
|
self._backoff_until: dict[str, float] = {}
|
|
self._pump_lock = asyncio.Lock()
|
|
self._notify_cb: Optional[Callable[[str], None]] = None
|
|
self._started = False
|
|
|
|
@property
|
|
def max_active(self) -> int:
|
|
return max(1, getattr(args, "download_max_active", 3))
|
|
|
|
def set_notify(self, cb: Optional[Callable[[str], None]]) -> None:
|
|
self._notify_cb = cb
|
|
|
|
def get_job(self, download_id: str) -> Optional[DownloadJob]:
|
|
return self._jobs.get(download_id)
|
|
|
|
def is_active(self, download_id: str) -> bool:
|
|
return download_id in self._tasks
|
|
|
|
# ----- startup -----
|
|
|
|
async def start(self) -> None:
|
|
if self._started:
|
|
return
|
|
self._started = True
|
|
try:
|
|
await asyncio.to_thread(queries.reconcile_live_downloads)
|
|
await asyncio.to_thread(self._sweep_orphan_temp_files)
|
|
except Exception as e:
|
|
logging.warning("[model_downloader] startup reconcile failed: %s", e)
|
|
await self.pump()
|
|
|
|
@staticmethod
|
|
def _sweep_orphan_temp_files() -> None:
|
|
"""Remove ``.part`` files not referenced by a resumable download row.
|
|
|
|
Resumable partials (queued/paused rows) are preserved; only truly
|
|
orphaned temp files from crashed runs are deleted.
|
|
"""
|
|
live = {
|
|
row.temp_path
|
|
for row in queries.list_downloads()
|
|
if row.status in (DownloadStatus.QUEUED, DownloadStatus.PAUSED)
|
|
}
|
|
for path in paths.iter_all_tmp_paths():
|
|
if path in live:
|
|
continue
|
|
try:
|
|
os.remove(path)
|
|
logging.info("[model_downloader] removed orphan temp file: %s", path)
|
|
except OSError as e:
|
|
logging.warning("[model_downloader] could not remove %s: %s", path, e)
|
|
|
|
# ----- admission -----
|
|
|
|
async def pump(self) -> None:
|
|
async with self._pump_lock:
|
|
slots = self.max_active - len(self._tasks)
|
|
if slots <= 0:
|
|
return
|
|
now = time.monotonic()
|
|
candidates = await asyncio.to_thread(queries.list_queued_downloads)
|
|
for row in candidates:
|
|
if slots <= 0:
|
|
break
|
|
if row.id in self._tasks:
|
|
continue
|
|
if self._backoff_until.get(row.id, 0.0) > now:
|
|
continue
|
|
self._admit(row)
|
|
slots -= 1
|
|
|
|
def _admit(self, row) -> None:
|
|
spec = JobSpec(
|
|
download_id=row.id,
|
|
url=row.url,
|
|
model_id=row.model_id,
|
|
dest_path=row.dest_path,
|
|
temp_path=row.temp_path,
|
|
priority=row.priority,
|
|
credential_id=row.credential_id,
|
|
expected_sha256=row.expected_sha256,
|
|
allow_any_extension=row.allow_any_extension,
|
|
etag=row.etag,
|
|
attempts=row.attempts,
|
|
)
|
|
job = DownloadJob(spec, notify_cb=self._notify_cb)
|
|
self._jobs[row.id] = job
|
|
self._tasks[row.id] = asyncio.ensure_future(self._run_job(job))
|
|
|
|
async def _run_job(self, job: DownloadJob) -> None:
|
|
download_id = job.spec.download_id
|
|
status = DownloadStatus.FAILED
|
|
try:
|
|
status = await job.run()
|
|
except Exception as e: # run() is defensive, but never let a task die silently
|
|
logging.error("[model_downloader] job %s crashed: %s", download_id, e)
|
|
finally:
|
|
self._tasks.pop(download_id, None)
|
|
self._jobs.pop(download_id, None)
|
|
|
|
if status == DownloadStatus.QUEUED:
|
|
if job.spec.attempts >= _MAX_ATTEMPTS:
|
|
queries.update_download(
|
|
download_id,
|
|
status=DownloadStatus.FAILED,
|
|
error=f"giving up after {job.spec.attempts} attempts",
|
|
)
|
|
if self._notify_cb:
|
|
self._notify_cb(download_id)
|
|
else:
|
|
delay = min(
|
|
_BACKOFF_CAP, _BACKOFF_BASE ** job.spec.attempts
|
|
) + random.uniform(0, 1.0)
|
|
self._backoff_until[download_id] = time.monotonic() + delay
|
|
asyncio.ensure_future(self._delayed_pump(delay))
|
|
await self.pump()
|
|
|
|
async def _delayed_pump(self, delay: float) -> None:
|
|
await asyncio.sleep(delay)
|
|
await self.pump()
|
|
|
|
|
|
SCHEDULER = Scheduler()
|