ComfyUI/app/model_downloader/engine/job.py
2026-06-30 20:33:15 +02:00

444 lines
16 KiB
Python

"""The per-download worker (PRD sections 5, 6, 8, 12).
One :class:`DownloadJob` drives a single file from probe to verified, cataloged
completion. It supports cooperative pause / resume / cancel, segmented
multi-connection transfer with positioned writes, and a verification gate
(size + structural + optional sha256) before the atomic rename into place.
Control is cooperative: external callers flip ``_control`` via
:meth:`request_pause` / :meth:`request_cancel`; segment loops observe it between
chunks and raise, which unwinds cleanly and persists resume offsets.
"""
from __future__ import annotations
import asyncio
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Callable, Optional
from comfy.cli_args import args
from app.model_downloader.constants import DownloadStatus
from app.model_downloader.database import queries
from app.model_downloader.engine.planner import (
SegmentPlan,
effective_segment_count,
plan_segments,
)
from app.model_downloader.engine.writer import FileWriter
from app.model_downloader.net.http import open_validated
from app.model_downloader.net.probe import probe
from app.model_downloader.verify import checksum, dedup, structural
_RETRYABLE_STATUSES = {408, 429, 500, 502, 503, 504}
_PERSIST_INTERVAL = 2.0 # seconds between throttled progress persists
class Paused(Exception):
pass
class Cancelled(Exception):
pass
class RemoteChanged(Exception):
"""The remote file changed under a resume (got 200 where 206 expected)."""
class RetryableError(Exception):
pass
class FatalError(Exception):
"""Non-retryable: 4xx, checksum mismatch, structural failure, gated, etc."""
@dataclass
class SegmentRuntime:
idx: int
start: int
end: int # inclusive; may be -1 for unknown-size single stream
bytes_done: int = 0
@property
def length(self) -> int:
return self.end - self.start + 1
@dataclass
class RuntimeState:
download_id: str
model_id: str
url: str
priority: int
status: str
total_bytes: Optional[int] = None
bytes_done: int = 0
error: Optional[str] = None
segments: list[SegmentRuntime] = field(default_factory=list)
started_at: float = field(default_factory=time.monotonic)
_last_bytes: int = 0
_last_time: float = field(default_factory=time.monotonic)
speed_bps: float = 0.0
@property
def progress(self) -> Optional[float]:
if not self.total_bytes:
return None
return min(1.0, self.bytes_done / self.total_bytes)
@property
def eta_seconds(self) -> Optional[float]:
if not self.total_bytes or self.speed_bps <= 0:
return None
remaining = max(0, self.total_bytes - self.bytes_done)
return remaining / self.speed_bps
@dataclass
class JobSpec:
download_id: str
url: str
model_id: str
dest_path: str
temp_path: str
priority: int = 0
credential_id: Optional[str] = None
expected_sha256: Optional[str] = None
allow_any_extension: bool = False
etag: Optional[str] = None
attempts: int = 0
class DownloadJob:
def __init__(
self, spec: JobSpec, notify_cb: Optional[Callable[[str], None]] = None
) -> None:
self.spec = spec
self._notify = notify_cb
self._control = "run" # run | pause | cancel
self.state = RuntimeState(
download_id=spec.download_id,
model_id=spec.model_id,
url=spec.url,
priority=spec.priority,
status=DownloadStatus.QUEUED,
)
self._writer: Optional[FileWriter] = None
self._etag: Optional[str] = spec.etag
self._last_persist = 0.0
# ----- external control -----
def request_pause(self) -> None:
if self._control == "run":
self._control = "pause"
def request_cancel(self) -> None:
self._control = "cancel"
def _check_control(self) -> None:
if self._control == "cancel":
raise Cancelled()
if self._control == "pause":
raise Paused()
# ----- lifecycle -----
async def run(self) -> str:
"""Run to a terminal/paused state; returns the final status string."""
self._set_status(DownloadStatus.ACTIVE, error=None)
try:
pr = await self._probe_and_plan()
await self._transfer(pr)
await self._finalize()
self._set_status(DownloadStatus.COMPLETED)
except Paused:
await self._persist_progress(force=True)
self._set_status(DownloadStatus.PAUSED)
except Cancelled:
await self._close_writer()
self._remove_temp()
self._set_status(DownloadStatus.CANCELLED)
except RemoteChanged:
await self._reset_for_restart()
self._set_status(
DownloadStatus.QUEUED, error="remote file changed; restarting"
)
except RetryableError as e:
await self._persist_progress(force=True)
self._set_status(DownloadStatus.QUEUED, error=str(e))
except FatalError as e:
await self._close_writer()
self._remove_temp()
self._set_status(DownloadStatus.FAILED, error=str(e))
except Exception as e: # unexpected -> treat as retryable
logging.warning(
"[model_downloader] %s unexpected error: %s",
self.spec.model_id, e, exc_info=True,
)
await self._persist_progress(force=True)
self._set_status(DownloadStatus.QUEUED, error=f"{type(e).__name__}: {e}")
finally:
await self._close_writer()
return self.state.status
# ----- probe + plan -----
async def _probe_and_plan(self):
pr = await probe(self.spec.url, credential_id=self.spec.credential_id)
if not pr.ok:
if pr.gated:
raise FatalError(
f"{self.spec.url} requires authentication. Add an API key for "
f"this host at /api/download/credentials and retry."
)
if pr.status == 0 or pr.status in _RETRYABLE_STATUSES:
raise RetryableError(pr.error or "probe failed")
raise FatalError(pr.error or f"probe returned HTTP {pr.status}")
self._etag = pr.etag or self._etag
self.state.total_bytes = pr.total_bytes
queries.update_download(
self.spec.download_id,
final_url=pr.final_url,
total_bytes=pr.total_bytes,
accept_ranges=pr.accept_ranges,
etag=pr.etag,
last_modified=pr.last_modified,
)
seg_count = effective_segment_count(
pr.total_bytes, pr.accept_ranges, max(1, args.download_segments)
)
existing = queries.list_segments(self.spec.download_id)
if (
seg_count > 1
and existing
and pr.total_bytes is not None
and existing[-1].end_offset == pr.total_bytes - 1
):
# Resume an existing segmented plan.
self.state.segments = [
SegmentRuntime(s.idx, s.start_offset, s.end_offset, s.bytes_done)
for s in existing
]
elif seg_count > 1 and pr.total_bytes is not None:
plans = plan_segments(pr.total_bytes, seg_count)
queries.replace_segments(
self.spec.download_id,
[
{"idx": p.idx, "start_offset": p.start, "end_offset": p.end, "bytes_done": 0}
for p in plans
],
)
self.state.segments = [SegmentRuntime(p.idx, p.start, p.end, 0) for p in plans]
else:
# Single-stream: one logical segment; bytes_done tracked on the row.
row = queries.get_download(self.spec.download_id)
resume_from = row.bytes_done if row else 0
end = (pr.total_bytes - 1) if pr.total_bytes else -1
self.state.segments = [SegmentRuntime(0, 0, end, resume_from)]
self._recompute_bytes_done()
return pr
# ----- transfer -----
async def _transfer(self, pr) -> None:
self._writer = FileWriter(self.spec.temp_path)
await self._writer.open()
segmented = len(self.state.segments) > 1
if segmented and self.state.total_bytes:
await self._writer.preallocate(self.state.total_bytes)
await self._run_segmented()
else:
await self._run_single()
await self._writer.flush()
async def _run_segmented(self) -> None:
pending = [
asyncio.ensure_future(self._run_segment(seg))
for seg in self.state.segments
if seg.bytes_done < seg.length
]
if not pending:
return
done, not_done = await asyncio.wait(
pending, return_when=asyncio.FIRST_EXCEPTION
)
first_exc: Optional[BaseException] = None
for task in done:
exc = task.exception()
if exc is not None and first_exc is None:
first_exc = exc
if first_exc is not None:
for task in not_done:
task.cancel()
await asyncio.gather(*not_done, return_exceptions=True)
raise first_exc
async def _run_segment(self, seg: SegmentRuntime) -> None:
offset = seg.start + seg.bytes_done
headers = {
"Range": f"bytes={offset}-{seg.end}",
"Accept-Encoding": "identity",
}
if self._etag:
headers["If-Range"] = self._etag
async with open_validated(
"GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers
) as (resp, _final):
if resp.status == 200:
# Server ignored the range -> remote changed / no resume support.
raise RemoteChanged()
if resp.status not in (206,):
self._raise_for_status(resp.status)
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
self._check_control()
await self._writer.write_at(offset, chunk)
offset += len(chunk)
seg.bytes_done += len(chunk)
self._recompute_bytes_done()
await self._persist_progress()
async def _run_single(self) -> None:
seg = self.state.segments[0]
offset = seg.bytes_done # resume from here for single-stream
headers = {"Accept-Encoding": "identity"}
if offset > 0:
headers["Range"] = f"bytes={offset}-"
if self._etag:
headers["If-Range"] = self._etag
async with open_validated(
"GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers
) as (resp, _final):
if offset > 0 and resp.status == 200:
# Resume not honoured -> start over from the beginning.
offset = 0
seg.bytes_done = 0
elif offset > 0 and resp.status != 206:
self._raise_for_status(resp.status)
elif offset == 0 and resp.status != 200:
self._raise_for_status(resp.status)
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
self._check_control()
await self._writer.write_at(offset, chunk)
offset += len(chunk)
seg.bytes_done = offset
self.state.bytes_done = offset
await self._persist_progress()
def _raise_for_status(self, status: int) -> None:
if status in (401, 403):
raise FatalError(
f"{self.spec.url} returned {status}; add/update an API key for "
f"this host at /api/download/credentials."
)
if status in _RETRYABLE_STATUSES:
raise RetryableError(f"HTTP {status}")
raise FatalError(f"unexpected HTTP {status}")
# ----- finalize / verify (PRD section 8.4) -----
async def _finalize(self) -> None:
self._check_control()
await self._close_writer()
self._set_status(DownloadStatus.VERIFYING)
total = self.state.total_bytes
actual_size = os.path.getsize(self.spec.temp_path)
if total is not None and actual_size != total:
raise FatalError(
f"size mismatch: wrote {actual_size} of {total} bytes"
)
# Structural gate (cheap, no full read) then optional sha256 (full read).
await asyncio.to_thread(structural.validate, self.spec.temp_path)
if self.spec.expected_sha256:
await asyncio.to_thread(
checksum.verify_sha256, self.spec.temp_path, self.spec.expected_sha256
)
os.makedirs(os.path.dirname(self.spec.dest_path), exist_ok=True)
os.replace(self.spec.temp_path, self.spec.dest_path)
logging.info(
"[model_downloader] completed %s (%d bytes)",
self.spec.model_id, actual_size,
)
# Catalog into the assets system (blake3 dedup identity). Best-effort.
await dedup.register_completed(self.spec.dest_path)
# ----- helpers -----
def _recompute_bytes_done(self) -> None:
self.state.bytes_done = sum(s.bytes_done for s in self.state.segments)
now = time.monotonic()
dt = now - self.state._last_time
if dt >= 0.5:
self.state.speed_bps = (self.state.bytes_done - self.state._last_bytes) / dt
self.state._last_bytes = self.state.bytes_done
self.state._last_time = now
async def _persist_progress(self, force: bool = False) -> None:
now = time.monotonic()
if not force and now - self._last_persist < _PERSIST_INTERVAL:
if self._notify:
self._notify(self.spec.download_id)
return
self._last_persist = now
queries.update_download(self.spec.download_id, bytes_done=self.state.bytes_done)
for seg in self.state.segments:
if seg.end >= seg.start: # skip unknown-size sentinel
queries.update_segment_progress(
self.spec.download_id, seg.idx, seg.bytes_done
)
if self._notify:
self._notify(self.spec.download_id)
async def _reset_for_restart(self) -> None:
await self._close_writer()
self._remove_temp()
for seg in self.state.segments:
seg.bytes_done = 0
self.state.bytes_done = 0
queries.update_download(self.spec.download_id, bytes_done=0)
if queries.list_segments(self.spec.download_id):
queries.replace_segments(self.spec.download_id, [])
async def _close_writer(self) -> None:
if self._writer is not None:
try:
await self._writer.close()
except Exception:
logging.debug("[model_downloader] writer close error", exc_info=True)
self._writer = None
def _remove_temp(self) -> None:
try:
os.remove(self.spec.temp_path)
except FileNotFoundError:
pass
except OSError as e:
logging.warning(
"[model_downloader] could not remove %s: %s", self.spec.temp_path, e
)
def _set_status(self, status: str, error: Optional[str] = None) -> None:
self.state.status = status
if error is not None:
self.state.error = error
fields = {"status": status, "bytes_done": self.state.bytes_done}
if error is not None:
fields["error"] = error
if status == DownloadStatus.QUEUED:
fields["attempts"] = self.spec.attempts + 1
self.spec.attempts += 1
queries.update_download(self.spec.download_id, **fields)
if self._notify:
self._notify(self.spec.download_id)