mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
476 lines
18 KiB
Python
476 lines
18 KiB
Python
"""The per-download worker.
|
|
|
|
One :class:`DownloadJob` drives a single file from probe to verified, cataloged
|
|
completion. It supports cooperative pause / resume / cancel, segmented
|
|
multi-connection transfer with positioned writes, and a verification gate
|
|
(size + structural + optional sha256) before the atomic rename into place.
|
|
|
|
Control is cooperative: external callers flip ``_control`` via
|
|
:meth:`request_pause` / :meth:`request_cancel`; segment loops observe it between
|
|
chunks and raise, which unwinds cleanly and persists resume offsets.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Callable, Optional
|
|
|
|
from comfy.cli_args import args
|
|
from app.model_downloader.constants import DownloadStatus
|
|
from app.model_downloader.database import queries
|
|
from app.model_downloader.engine.planner import (
|
|
effective_segment_count,
|
|
plan_segments,
|
|
)
|
|
from app.model_downloader.engine.writer import FileWriter
|
|
from app.model_downloader.net.http import open_validated
|
|
from app.model_downloader.net.probe import probe
|
|
from app.model_downloader.verify import checksum, dedup, structural
|
|
|
|
_RETRYABLE_STATUSES = {408, 429, 500, 502, 503, 504}
|
|
_PERSIST_INTERVAL = 2.0 # seconds between throttled progress persists
|
|
|
|
|
|
class Paused(Exception):
|
|
pass
|
|
|
|
|
|
class Cancelled(Exception):
|
|
pass
|
|
|
|
|
|
class RemoteChanged(Exception):
|
|
"""The remote file changed under a resume (got 200 where 206 expected)."""
|
|
|
|
|
|
class RetryableError(Exception):
|
|
pass
|
|
|
|
|
|
class FatalError(Exception):
|
|
"""Non-retryable: 4xx, checksum mismatch, structural failure, gated, etc."""
|
|
|
|
|
|
@dataclass
|
|
class SegmentRuntime:
|
|
idx: int
|
|
start: int
|
|
end: int # inclusive; may be -1 for unknown-size single stream
|
|
bytes_done: int = 0
|
|
|
|
@property
|
|
def length(self) -> int:
|
|
return self.end - self.start + 1
|
|
|
|
|
|
@dataclass
|
|
class RuntimeState:
|
|
download_id: str
|
|
model_id: str
|
|
url: str
|
|
priority: int
|
|
status: str
|
|
total_bytes: Optional[int] = None
|
|
bytes_done: int = 0
|
|
error: Optional[str] = None
|
|
segments: list[SegmentRuntime] = field(default_factory=list)
|
|
started_at: float = field(default_factory=time.monotonic)
|
|
_last_bytes: int = 0
|
|
_last_time: float = field(default_factory=time.monotonic)
|
|
speed_bps: float = 0.0
|
|
|
|
@property
|
|
def progress(self) -> Optional[float]:
|
|
if not self.total_bytes:
|
|
return None
|
|
return min(1.0, self.bytes_done / self.total_bytes)
|
|
|
|
@property
|
|
def eta_seconds(self) -> Optional[float]:
|
|
if not self.total_bytes or self.speed_bps <= 0:
|
|
return None
|
|
remaining = max(0, self.total_bytes - self.bytes_done)
|
|
return remaining / self.speed_bps
|
|
|
|
|
|
@dataclass
|
|
class JobSpec:
|
|
download_id: str
|
|
url: str
|
|
model_id: str
|
|
dest_path: str
|
|
temp_path: str
|
|
priority: int = 0
|
|
credential_id: Optional[str] = None
|
|
expected_sha256: Optional[str] = None
|
|
allow_any_extension: bool = False
|
|
etag: Optional[str] = None
|
|
attempts: int = 0
|
|
|
|
|
|
class DownloadJob:
|
|
def __init__(
|
|
self, spec: JobSpec, notify_cb: Optional[Callable[[str], None]] = None
|
|
) -> None:
|
|
self.spec = spec
|
|
self._notify = notify_cb
|
|
self._control = "run" # run | pause | cancel
|
|
self.state = RuntimeState(
|
|
download_id=spec.download_id,
|
|
model_id=spec.model_id,
|
|
url=spec.url,
|
|
priority=spec.priority,
|
|
status=DownloadStatus.QUEUED,
|
|
)
|
|
self._writer: Optional[FileWriter] = None
|
|
self._etag: Optional[str] = spec.etag
|
|
self._last_persist = 0.0
|
|
|
|
# ----- external control -----
|
|
|
|
def request_pause(self) -> None:
|
|
if self._control == "run":
|
|
self._control = "pause"
|
|
|
|
def request_cancel(self) -> None:
|
|
self._control = "cancel"
|
|
|
|
def _check_control(self) -> None:
|
|
if self._control == "cancel":
|
|
raise Cancelled()
|
|
if self._control == "pause":
|
|
raise Paused()
|
|
|
|
# ----- lifecycle -----
|
|
|
|
async def run(self) -> str:
|
|
"""Run to a terminal/paused state; returns the final status string."""
|
|
self._set_status(DownloadStatus.ACTIVE, error=None)
|
|
try:
|
|
pr = await self._probe_and_plan()
|
|
await self._transfer(pr)
|
|
await self._finalize()
|
|
self._set_status(DownloadStatus.COMPLETED)
|
|
except Paused:
|
|
await self._persist_progress(force=True)
|
|
self._set_status(DownloadStatus.PAUSED)
|
|
except Cancelled:
|
|
await self._close_writer()
|
|
self._remove_temp()
|
|
self._set_status(DownloadStatus.CANCELLED)
|
|
except RemoteChanged:
|
|
await self._reset_for_restart()
|
|
self._set_status(
|
|
DownloadStatus.QUEUED, error="remote file changed; restarting"
|
|
)
|
|
except RetryableError as e:
|
|
await self._persist_progress(force=True)
|
|
self._set_status(DownloadStatus.QUEUED, error=str(e))
|
|
except FatalError as e:
|
|
await self._close_writer()
|
|
self._remove_temp()
|
|
self._set_status(DownloadStatus.FAILED, error=str(e))
|
|
except Exception as e: # unexpected -> treat as retryable
|
|
logging.warning(
|
|
"[model_downloader] %s unexpected error: %s",
|
|
self.spec.model_id, e, exc_info=True,
|
|
)
|
|
await self._persist_progress(force=True)
|
|
self._set_status(DownloadStatus.QUEUED, error=f"{type(e).__name__}: {e}")
|
|
finally:
|
|
await self._close_writer()
|
|
return self.state.status
|
|
|
|
# ----- probe + plan -----
|
|
|
|
async def _probe_and_plan(self):
|
|
pr = await probe(self.spec.url, credential_id=self.spec.credential_id)
|
|
if not pr.ok:
|
|
if pr.gated:
|
|
raise FatalError(
|
|
f"{self.spec.url} requires authentication. Add an API key for "
|
|
f"this host at /api/download/credentials and retry."
|
|
)
|
|
if pr.status == 0 or pr.status in _RETRYABLE_STATUSES:
|
|
raise RetryableError(pr.error or "probe failed")
|
|
raise FatalError(pr.error or f"probe returned HTTP {pr.status}")
|
|
|
|
self._etag = pr.etag or self._etag
|
|
self.state.total_bytes = pr.total_bytes
|
|
queries.update_download(
|
|
self.spec.download_id,
|
|
final_url=pr.final_url,
|
|
total_bytes=pr.total_bytes,
|
|
accept_ranges=pr.accept_ranges,
|
|
etag=pr.etag,
|
|
last_modified=pr.last_modified,
|
|
)
|
|
|
|
seg_count = effective_segment_count(
|
|
pr.total_bytes, pr.accept_ranges, max(1, args.download_segments)
|
|
)
|
|
existing = queries.list_segments(self.spec.download_id)
|
|
if (
|
|
seg_count > 1
|
|
and existing
|
|
and pr.total_bytes is not None
|
|
and existing[-1].end_offset == pr.total_bytes - 1
|
|
):
|
|
# Resume an existing segmented plan.
|
|
self.state.segments = [
|
|
SegmentRuntime(s.idx, s.start_offset, s.end_offset, s.bytes_done)
|
|
for s in existing
|
|
]
|
|
elif seg_count > 1 and pr.total_bytes is not None:
|
|
plans = plan_segments(pr.total_bytes, seg_count)
|
|
queries.replace_segments(
|
|
self.spec.download_id,
|
|
[
|
|
{"idx": p.idx, "start_offset": p.start, "end_offset": p.end, "bytes_done": 0}
|
|
for p in plans
|
|
],
|
|
)
|
|
self.state.segments = [SegmentRuntime(p.idx, p.start, p.end, 0) for p in plans]
|
|
else:
|
|
# Single-stream: one logical segment; bytes_done tracked on the row.
|
|
row = queries.get_download(self.spec.download_id)
|
|
resume_from = row.bytes_done if row else 0
|
|
end = (pr.total_bytes - 1) if pr.total_bytes else -1
|
|
# ``row.bytes_done`` may be the SUM of per-segment offsets from a
|
|
# prior segmented run (a preallocated, non-contiguous .part). A
|
|
# single-stream resume writes a contiguous prefix, so the offset is
|
|
# only trustworthy when the on-disk file is exactly that many
|
|
# contiguous bytes. This guards the case where a download that ran
|
|
# segmented now resolves to one segment (server dropped
|
|
# Accept-Ranges, or --download-segments was lowered between runs):
|
|
# resuming over non-contiguous data would corrupt the output.
|
|
if resume_from > 0 and not self._contiguous_prefix_valid(resume_from):
|
|
logging.info(
|
|
"[model_downloader] %s discarding untrusted resume offset "
|
|
"%d (on-disk .part not a contiguous prefix); restarting",
|
|
self.spec.model_id, resume_from,
|
|
)
|
|
resume_from = 0
|
|
self._remove_temp()
|
|
if queries.list_segments(self.spec.download_id):
|
|
queries.replace_segments(self.spec.download_id, [])
|
|
queries.update_download(self.spec.download_id, bytes_done=0)
|
|
self.state.segments = [SegmentRuntime(0, 0, end, resume_from)]
|
|
self._recompute_bytes_done()
|
|
return pr
|
|
|
|
# ----- transfer -----
|
|
|
|
async def _transfer(self, pr) -> None:
|
|
self._writer = FileWriter(self.spec.temp_path)
|
|
await self._writer.open()
|
|
|
|
segmented = len(self.state.segments) > 1
|
|
if segmented and self.state.total_bytes:
|
|
await self._writer.preallocate(self.state.total_bytes)
|
|
await self._run_segmented()
|
|
else:
|
|
await self._run_single()
|
|
|
|
await self._writer.flush()
|
|
|
|
async def _run_segmented(self) -> None:
|
|
pending = [
|
|
asyncio.ensure_future(self._run_segment(seg))
|
|
for seg in self.state.segments
|
|
if seg.bytes_done < seg.length
|
|
]
|
|
if not pending:
|
|
return
|
|
done, not_done = await asyncio.wait(
|
|
pending, return_when=asyncio.FIRST_EXCEPTION
|
|
)
|
|
first_exc: Optional[BaseException] = None
|
|
for task in done:
|
|
exc = task.exception()
|
|
if exc is not None and first_exc is None:
|
|
first_exc = exc
|
|
if first_exc is not None:
|
|
for task in not_done:
|
|
task.cancel()
|
|
await asyncio.gather(*not_done, return_exceptions=True)
|
|
raise first_exc
|
|
|
|
async def _run_segment(self, seg: SegmentRuntime) -> None:
|
|
offset = seg.start + seg.bytes_done
|
|
headers = {
|
|
"Range": f"bytes={offset}-{seg.end}",
|
|
"Accept-Encoding": "identity",
|
|
}
|
|
if self._etag:
|
|
headers["If-Range"] = self._etag
|
|
async with open_validated(
|
|
"GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers
|
|
) as (resp, _final):
|
|
if resp.status == 200:
|
|
# Server ignored the range -> remote changed / no resume support.
|
|
raise RemoteChanged()
|
|
if resp.status not in (206,):
|
|
self._raise_for_status(resp.status)
|
|
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
|
|
self._check_control()
|
|
await self._writer.write_at(offset, chunk)
|
|
offset += len(chunk)
|
|
seg.bytes_done += len(chunk)
|
|
self._recompute_bytes_done()
|
|
await self._persist_progress()
|
|
|
|
async def _run_single(self) -> None:
|
|
seg = self.state.segments[0]
|
|
offset = seg.bytes_done # resume from here for single-stream
|
|
headers = {"Accept-Encoding": "identity"}
|
|
if offset > 0:
|
|
headers["Range"] = f"bytes={offset}-"
|
|
if self._etag:
|
|
headers["If-Range"] = self._etag
|
|
async with open_validated(
|
|
"GET", self.spec.url, credential_id=self.spec.credential_id, headers=headers
|
|
) as (resp, _final):
|
|
if offset > 0 and resp.status == 200:
|
|
# Resume not honoured -> start over from the beginning.
|
|
offset = 0
|
|
seg.bytes_done = 0
|
|
elif offset > 0 and resp.status != 206:
|
|
self._raise_for_status(resp.status)
|
|
elif offset == 0 and resp.status != 200:
|
|
self._raise_for_status(resp.status)
|
|
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
|
|
self._check_control()
|
|
await self._writer.write_at(offset, chunk)
|
|
offset += len(chunk)
|
|
seg.bytes_done = offset
|
|
self.state.bytes_done = offset
|
|
await self._persist_progress()
|
|
|
|
def _raise_for_status(self, status: int) -> None:
|
|
if status in (401, 403):
|
|
raise FatalError(
|
|
f"{self.spec.url} returned {status}; add/update an API key for "
|
|
f"this host at /api/download/credentials."
|
|
)
|
|
if status in _RETRYABLE_STATUSES:
|
|
raise RetryableError(f"HTTP {status}")
|
|
raise FatalError(f"unexpected HTTP {status}")
|
|
|
|
# ----- finalize / verify (PRD section 8.4) -----
|
|
|
|
async def _finalize(self) -> None:
|
|
self._check_control()
|
|
await self._close_writer()
|
|
self._set_status(DownloadStatus.VERIFYING)
|
|
|
|
total = self.state.total_bytes
|
|
actual_size = os.path.getsize(self.spec.temp_path)
|
|
if total is not None and actual_size != total:
|
|
raise FatalError(
|
|
f"size mismatch: wrote {actual_size} of {total} bytes"
|
|
)
|
|
|
|
# Structural gate (cheap, no full read) then optional sha256 (full read).
|
|
await asyncio.to_thread(structural.validate, self.spec.temp_path)
|
|
if self.spec.expected_sha256:
|
|
await asyncio.to_thread(
|
|
checksum.verify_sha256, self.spec.temp_path, self.spec.expected_sha256
|
|
)
|
|
|
|
os.makedirs(os.path.dirname(self.spec.dest_path), exist_ok=True)
|
|
os.replace(self.spec.temp_path, self.spec.dest_path)
|
|
logging.info(
|
|
"[model_downloader] completed %s (%d bytes)",
|
|
self.spec.model_id, actual_size,
|
|
)
|
|
# Catalog into the assets system (blake3 dedup identity). Best-effort.
|
|
await dedup.register_completed(self.spec.dest_path)
|
|
|
|
# ----- helpers -----
|
|
|
|
def _recompute_bytes_done(self) -> None:
|
|
self.state.bytes_done = sum(s.bytes_done for s in self.state.segments)
|
|
now = time.monotonic()
|
|
dt = now - self.state._last_time
|
|
if dt >= 0.5:
|
|
self.state.speed_bps = (self.state.bytes_done - self.state._last_bytes) / dt
|
|
self.state._last_bytes = self.state.bytes_done
|
|
self.state._last_time = now
|
|
|
|
async def _persist_progress(self, force: bool = False) -> None:
|
|
now = time.monotonic()
|
|
if not force and now - self._last_persist < _PERSIST_INTERVAL:
|
|
if self._notify:
|
|
self._notify(self.spec.download_id)
|
|
return
|
|
self._last_persist = now
|
|
queries.update_download(self.spec.download_id, bytes_done=self.state.bytes_done)
|
|
for seg in self.state.segments:
|
|
if seg.end >= seg.start: # skip unknown-size sentinel
|
|
queries.update_segment_progress(
|
|
self.spec.download_id, seg.idx, seg.bytes_done
|
|
)
|
|
if self._notify:
|
|
self._notify(self.spec.download_id)
|
|
|
|
async def _reset_for_restart(self) -> None:
|
|
await self._close_writer()
|
|
self._remove_temp()
|
|
for seg in self.state.segments:
|
|
seg.bytes_done = 0
|
|
self.state.bytes_done = 0
|
|
queries.update_download(self.spec.download_id, bytes_done=0)
|
|
if queries.list_segments(self.spec.download_id):
|
|
queries.replace_segments(self.spec.download_id, [])
|
|
|
|
async def _close_writer(self) -> None:
|
|
if self._writer is not None:
|
|
try:
|
|
await self._writer.close()
|
|
except Exception:
|
|
logging.debug("[model_downloader] writer close error", exc_info=True)
|
|
self._writer = None
|
|
|
|
def _contiguous_prefix_valid(self, prefix_len: int) -> bool:
|
|
"""True when the temp file is exactly ``prefix_len`` contiguous bytes.
|
|
|
|
Single-stream resume appends sequentially, so a valid resume point
|
|
implies the .part size equals the persisted offset. A larger file (e.g.
|
|
one preallocated to ``total_bytes`` by a previous segmented run) or a
|
|
missing/short file means the persisted offset is not a trustworthy
|
|
contiguous prefix and must not be resumed over.
|
|
"""
|
|
try:
|
|
return os.path.getsize(self.spec.temp_path) == prefix_len
|
|
except OSError:
|
|
return False
|
|
|
|
def _remove_temp(self) -> None:
|
|
try:
|
|
os.remove(self.spec.temp_path)
|
|
except FileNotFoundError:
|
|
pass
|
|
except OSError as e:
|
|
logging.warning(
|
|
"[model_downloader] could not remove %s: %s", self.spec.temp_path, e
|
|
)
|
|
|
|
def _set_status(self, status: str, error: Optional[str] = None) -> None:
|
|
self.state.status = status
|
|
if error is not None:
|
|
self.state.error = error
|
|
fields = {"status": status, "bytes_done": self.state.bytes_done}
|
|
if error is not None:
|
|
fields["error"] = error
|
|
if status == DownloadStatus.QUEUED:
|
|
fields["attempts"] = self.spec.attempts + 1
|
|
self.spec.attempts += 1
|
|
queries.update_download(self.spec.download_id, **fields)
|
|
if self._notify:
|
|
self._notify(self.spec.download_id)
|