mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 16:29:25 +08:00
Self-contained package under app/model_downloader/: - Allowlist + path-validated downloads (SSRF guard: HF/Civitai/localhost + model extension). - Streaming worker: writes to <final>.tmp, atomic rename on success, cooperative cancellation with epoch-based session identity, orphan .tmp sweep. - Unified availability probe with per-URL gated/size caching; is_hf_downloadable recomputed per call so login/license changes surface within one poll. - HuggingFace OAuth 2.0 PKCE flow with loopback callback server and on-disk (0600) token storage + transparent refresh. - Pydantic request/response schemas and aiohttp routes under api/. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
180 lines
6.6 KiB
Python
180 lines
6.6 KiB
Python
"""Process-wide registry of in-flight model downloads.
|
|
|
|
A single instance, ``DOWNLOAD_SERVER``, tracks every currently-running
|
|
server-side model fetch. Designed to be safe with multiple concurrent
|
|
clients hitting the API: each model_id has at most one active session,
|
|
and the API rejects requests that conflict with in-flight downloads.
|
|
|
|
Cancellation is cooperative — the download loop checks ``is_active`` on
|
|
its own session between chunks and raises ``DownloadCancelled`` when the
|
|
session has been removed from the registry. This avoids the complications
|
|
of ``Task.cancel()`` from outside the loop while still giving deterministic
|
|
rollback semantics (the worker is responsible for deleting its own
|
|
``.tmp`` on the cancel path).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
import threading
|
|
from dataclasses import dataclass, field
|
|
from typing import Optional
|
|
|
|
from app.model_downloader.paths import iter_all_tmp_paths
|
|
|
|
|
|
class DownloadCancelled(Exception):
    """Signal that a download's session is no longer registered.

    The streaming loop raises this after a cancellation request has
    removed its session from the registry; the worker is then expected
    to roll back (delete) its ``.tmp`` file.
    """
|
|
|
|
|
|
@dataclass
class DownloadSession:
    """One in-flight download.

    ``progress`` is a fraction in ``[0.0, 1.0]``; it stays ``None`` until
    the first byte arrives and we know whether the response carries a
    ``Content-Length``. ``total_bytes`` mirrors that header when present.
    """

    # Identifier of the model being fetched; also the registry key.
    model_id: str
    # Source URL of the download.
    url: str
    # Completed fraction, or None while the total size is unknown.
    progress: Optional[float] = None
    # Number of bytes received so far.
    bytes_downloaded: int = 0
    # Expected total size in bytes, or None when not known.
    total_bytes: Optional[int] = None
    # Sequence number used solely as identity for the cancellation check —
    # so that "cancel + restart" doesn't get confused by stale workers.
    # A plain immutable default is enough here; no default_factory needed.
    epoch: int = 0
|
|
|
|
|
|
class DownloadServer:
    """Singleton registry of active downloads.

    All mutation goes through this object so concurrent route handlers
    see a consistent view. ``_lock`` is a plain threading lock because
    the registry is consulted from both the asyncio event-loop thread
    (route handlers) and from any worker coroutines spawned to perform
    downloads. ``_sweep_lock`` separately serialises the one-time orphan
    sweep so slow filesystem I/O never blocks registry operations.
    """

    def __init__(self) -> None:
        # Guards ``_sessions`` and ``_epoch_counter``.
        self._lock = threading.Lock()
        self._sessions: dict[str, DownloadSession] = {}
        self._epoch_counter = 0
        # Guards ``_orphan_sweep_done`` and is held for the whole sweep, so
        # a concurrent caller cannot return before the sweep has finished.
        self._sweep_lock = threading.Lock()
        self._orphan_sweep_done = False

    # ----- lifecycle -----

    def sweep_orphan_tmp_files(self) -> None:
        """Idempotently sweep ``*.tmp`` files left by crashed downloads.

        Deferred off the import path so module load doesn't block on
        filesystem I/O against potentially-slow mounts. Each route handler
        that might create a new ``.tmp`` calls this; only the first call
        performs the sweep.

        The sweep runs entirely under ``_sweep_lock``: a caller arriving
        while another thread is still sweeping waits for it to finish
        instead of returning early. Without that, the late caller could
        create a fresh ``.tmp`` that the in-progress sweep then deletes.
        """
        with self._sweep_lock:
            if self._orphan_sweep_done:
                return
            # Mark as done before iterating so that a sweep which raises
            # part-way is not retried on every subsequent request.
            self._orphan_sweep_done = True
            for path in iter_all_tmp_paths():
                try:
                    os.remove(path)
                    logging.info("[model_downloader] removed orphan tmp file: %s", path)
                except OSError as e:
                    # Best effort: an undeletable file must not fail the request.
                    logging.warning("[model_downloader] could not remove %s: %s", path, e)

    # ----- queries -----

    def is_downloading(self, model_id: str) -> bool:
        """Return True if a session is currently registered for ``model_id``."""
        with self._lock:
            return model_id in self._sessions

    def get(self, model_id: str) -> Optional[DownloadSession]:
        """Return the registered session for ``model_id``, or ``None``.

        The returned object is the live session, not a copy.
        """
        with self._lock:
            return self._sessions.get(model_id)

    def snapshot(self) -> dict[str, DownloadSession]:
        """Return a shallow copy of the current sessions map."""
        with self._lock:
            return dict(self._sessions)

    # ----- mutations -----

    def try_register(self, model_id: str, url: str) -> Optional[DownloadSession]:
        """Atomically register a new session iff none exists for ``model_id``.

        Returns the new session on success, ``None`` if a session is already
        in flight. Callers must check the return value — the caller is the
        sole owner of the session it gets back.
        """
        with self._lock:
            if model_id in self._sessions:
                return None
            # Each registration gets a fresh, strictly increasing epoch so a
            # restarted download is distinguishable from its predecessor.
            self._epoch_counter += 1
            session = DownloadSession(
                model_id=model_id,
                url=url,
                epoch=self._epoch_counter,
            )
            self._sessions[model_id] = session
            return session

    def update_progress(
        self,
        session: DownloadSession,
        bytes_downloaded: int,
        total_bytes: Optional[int],
    ) -> None:
        """Update progress on a session.

        No-op if the session has been removed (cancelled) or replaced by a
        newer one — the caller should check ``is_active`` separately.
        ``progress`` is only recomputed when ``total_bytes`` is a positive
        number, and is capped at ``1.0``.
        """
        with self._lock:
            current = self._sessions.get(session.model_id)
            if current is None or current.epoch != session.epoch:
                return
            current.bytes_downloaded = bytes_downloaded
            current.total_bytes = total_bytes
            if total_bytes is not None and total_bytes > 0:
                current.progress = min(1.0, bytes_downloaded / total_bytes)

    def is_active(self, session: DownloadSession) -> bool:
        """Return True iff this exact session is still the registered one.

        False after cancellation, after completion, or if another session
        has replaced it for the same model_id. Identity is decided by epoch.
        """
        with self._lock:
            current = self._sessions.get(session.model_id)
            return current is not None and current.epoch == session.epoch

    def finish(self, session: DownloadSession) -> None:
        """Remove a completed (or cancelled) session from the registry.

        Safe to call multiple times. Only removes if the epoch matches —
        we never accidentally evict a *newer* session for the same model_id.
        """
        with self._lock:
            current = self._sessions.get(session.model_id)
            if current is not None and current.epoch == session.epoch:
                del self._sessions[session.model_id]

    def reset_for_tests(self) -> None:
        """Clear all sessions and reset the epoch counter. Test-only.

        The orphan-sweep flag is left as it is.
        """
        with self._lock:
            self._sessions.clear()
            self._epoch_counter = 0

    def cancel(self, model_id: str) -> bool:
        """Remove the session registered for ``model_id``.

        Returns True if there was an active session to cancel. The worker
        will discover the cancellation on its next ``is_active`` check
        and roll back its ``.tmp`` file.
        """
        with self._lock:
            # Registered values are always session objects, never None, so a
            # None result from pop means nothing was registered.
            return self._sessions.pop(model_id, None) is not None
|
|
|
|
|
|
# Process-wide singleton registry of in-flight downloads, created when this
# module is imported.
DOWNLOAD_SERVER = DownloadServer()
|