ComfyUI/app/model_downloader/download_server.py
DoronGenzelHass fdd84d04a0 feat(model_downloader): server-side model download + HuggingFace OAuth subsystem
Self-contained package under app/model_downloader/:
- Allowlist + path-validated downloads (SSRF guard: HF/Civitai/localhost + model extension).
- Streaming worker: writes to <final>.tmp, atomic rename on success, cooperative cancellation with epoch-based session identity, orphan .tmp sweep.
- Unified availability probe with per-URL gated/size caching; is_hf_downloadable recomputed per call so login/license changes surface within one poll.
- HuggingFace OAuth 2.0 PKCE flow with loopback callback server and on-disk (0600) token storage + transparent refresh.
- Pydantic request/response schemas and aiohttp routes under api/.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-22 15:16:59 +03:00

180 lines
6.6 KiB
Python

"""Process-wide registry of in-flight model downloads.
A single instance, ``DOWNLOAD_SERVER``, tracks every currently-running
server-side model fetch. Designed to be safe with multiple concurrent
clients hitting the API: each model_id has at most one active session,
and the API rejects requests that conflict with in-flight downloads.
Cancellation is cooperative — the download loop checks ``is_active`` on
its own session between chunks and raises ``DownloadCancelled`` when the
session has been removed from the registry. This avoids the complications
of ``Task.cancel()`` from outside the loop while still giving deterministic
rollback semantics (the worker is responsible for deleting its own
``.tmp`` on the cancel path).
"""
from __future__ import annotations
import logging
import os
import threading
from dataclasses import dataclass, field
from typing import Optional
from app.model_downloader.paths import iter_all_tmp_paths
class DownloadCancelled(Exception):
    """Signal that a download's session is no longer registered.

    Raised by the streaming loop when its session has been removed
    from the registry (cancellation request) and the worker should roll
    back its ``.tmp`` file.
    """
@dataclass
class DownloadSession:
    """One in-flight download.

    ``progress`` is a fraction in ``[0.0, 1.0]``; ``None`` until the first
    byte arrives and we know whether the response carries a
    ``Content-Length``. ``total_bytes`` mirrors that header when present.
    """

    # Registry key: the registry holds at most one session per model_id.
    model_id: str
    # Source URL the download was registered with.
    url: str
    progress: Optional[float] = None
    bytes_downloaded: int = 0
    total_bytes: Optional[int] = None
    # Sequence number used solely as identity for the cancellation check —
    # so that "cancel + restart" doesn't get confused by stale workers.
    # An int is immutable, so a plain default suffices (no default_factory).
    epoch: int = 0
class DownloadServer:
    """Singleton registry of active downloads.

    All mutation goes through this object so concurrent route handlers
    see a consistent view. The ``_lock`` is a plain threading lock
    because the registry is consulted from both the asyncio event-loop
    thread (route handlers) and from any worker coroutines spawned to
    perform downloads.

    Sessions are keyed by ``model_id``. Each registered session carries an
    ``epoch`` taken from a per-registry counter; the epoch (not object
    identity) is what ``is_active``, ``update_progress`` and ``finish``
    compare to decide whether a caller's session is still the registered
    one.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._sessions: dict[str, DownloadSession] = {}
        self._epoch_counter = 0
        self._orphan_sweep_done = False

    # ----- lifecycle -----

    def sweep_orphan_tmp_files(self) -> None:
        """Idempotently sweep ``*.tmp`` files left by crashed downloads.

        Deferred off the import path so module load doesn't block on
        filesystem I/O against potentially-slow mounts. Each route handler
        that might create a new ``.tmp`` calls this; only the first call
        on this instance performs the sweep, later calls return at once.

        Removal failures (``OSError``) are logged and skipped rather than
        raised.
        """
        with self._lock:
            if self._orphan_sweep_done:
                return
            # The flag is set *before* the sweep runs and the sweep itself
            # happens outside the lock: a caller arriving while the sweep
            # is still in progress returns immediately, and a failed sweep
            # is not retried.
            self._orphan_sweep_done = True
        for path in iter_all_tmp_paths():
            try:
                os.remove(path)
                logging.info("[model_downloader] removed orphan tmp file: %s", path)
            except OSError as e:
                logging.warning("[model_downloader] could not remove %s: %s", path, e)

    # ----- queries -----

    def is_downloading(self, model_id: str) -> bool:
        """Return True iff a session is currently registered for ``model_id``."""
        with self._lock:
            return model_id in self._sessions

    def get(self, model_id: str) -> Optional[DownloadSession]:
        """Return the registered session for ``model_id``, or ``None``.

        The returned object is the live registry entry (not a copy), so
        its progress fields change as ``update_progress`` is called.
        """
        with self._lock:
            return self._sessions.get(model_id)

    def snapshot(self) -> dict[str, DownloadSession]:
        """Return a shallow copy of the current sessions map.

        The dict is new, but its values are the live session objects.
        """
        with self._lock:
            return dict(self._sessions)

    # ----- mutations -----

    def try_register(self, model_id: str, url: str) -> Optional[DownloadSession]:
        """Atomically register a new session iff none exists for ``model_id``.

        Returns the new session on success, ``None`` if a session is already
        in flight. Callers must check the return value — the caller is the
        sole owner of the session it gets back.
        """
        with self._lock:
            if model_id in self._sessions:
                return None
            # Epochs come from a single counter so each new registration
            # gets a value distinct from earlier ones on this instance.
            self._epoch_counter += 1
            session = DownloadSession(
                model_id=model_id,
                url=url,
                epoch=self._epoch_counter,
            )
            self._sessions[model_id] = session
            return session

    def update_progress(
        self,
        session: DownloadSession,
        bytes_downloaded: int,
        total_bytes: Optional[int],
    ) -> None:
        """Update progress on a session.

        No-op if the session has been removed (cancelled) or replaced by a
        session with a different epoch — caller should check ``is_active``
        separately. ``progress`` is only recomputed when ``total_bytes`` is
        a positive number; otherwise it keeps its previous value.
        """
        with self._lock:
            current = self._sessions.get(session.model_id)
            if current is None or current.epoch != session.epoch:
                return
            # Write to the registered entry, which is what readers observe.
            current.bytes_downloaded = bytes_downloaded
            current.total_bytes = total_bytes
            if total_bytes is not None and total_bytes > 0:
                # Clamp so an over-long body never reports more than 100%.
                current.progress = min(1.0, bytes_downloaded / total_bytes)

    def is_active(self, session: DownloadSession) -> bool:
        """True iff this exact session is still the registered one for
        its model_id. False after cancellation, after completion, or if
        another session has replaced it."""
        with self._lock:
            return self._is_current(session)

    def finish(self, session: DownloadSession) -> None:
        """Remove a completed (or cancelled) session from the registry.

        Safe to call multiple times. Only removes if the epoch matches —
        we never accidentally evict a *newer* session for the same model_id.
        """
        with self._lock:
            if self._is_current(session):
                del self._sessions[session.model_id]

    def reset_for_tests(self) -> None:
        """Clear all sessions and reset the epoch counter. Test-only.

        The orphan-sweep flag is left untouched.
        """
        with self._lock:
            self._sessions.clear()
            self._epoch_counter = 0

    def cancel(self, model_id: str) -> bool:
        """Remove the session registered for ``model_id``.

        Returns True if there was an active session to cancel. The worker
        will discover the cancellation on its next ``is_active`` check
        and roll back its ``.tmp`` file.
        """
        with self._lock:
            # Single lookup: pop reports whether an entry was present.
            return self._sessions.pop(model_id, None) is not None

    # ----- internals -----

    def _is_current(self, session: DownloadSession) -> bool:
        """Epoch check shared by ``is_active`` and ``finish``.

        Must be called with ``self._lock`` held.
        """
        current = self._sessions.get(session.model_id)
        return current is not None and current.epoch == session.epoch
# Module-level registry instance, created when this module is imported.
DOWNLOAD_SERVER = DownloadServer()