ComfyUI/app/model_downloader/downloader.py
DoronGenzelHass fdd84d04a0 feat(model_downloader): server-side model download + HuggingFace OAuth subsystem
Self-contained package under app/model_downloader/:
- Allowlist + path-validated downloads (SSRF guard: HF/Civitai/localhost + model extension).
- Streaming worker: writes to <final>.tmp, atomic rename on success, cooperative cancellation with epoch-based session identity, orphan .tmp sweep.
- Unified availability probe with per-URL gated/size caching; is_hf_downloadable recomputed per call so login/license changes surface within one poll.
- HuggingFace OAuth 2.0 PKCE flow with loopback callback server and on-disk (0600) token storage + transparent refresh.
- Pydantic request/response schemas and aiohttp routes under api/.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-22 15:16:59 +03:00

206 lines
7.8 KiB
Python

"""Streaming download worker with progress reporting and cancellation.
Each download writes to ``<final_path>.tmp`` and atomically renames into
place on success. Between chunks the worker checks the registry for
cancellation (via ``DownloadServer.is_active``) and rolls back its
``.tmp`` on cancel or on any error.
"""
from __future__ import annotations
import asyncio
import logging
import os
from typing import Optional
import aiohttp
from app.model_downloader.download_server import (
DOWNLOAD_SERVER,
DownloadCancelled,
DownloadSession,
)
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
from app.model_downloader.hf_url import is_hf_url
from app.model_downloader.http_client import get_session, parse_content_length
from app.model_downloader.paths import resolve_destination
# Read size used when streaming response bodies to disk: 64 KiB, the same
# order of magnitude as other ComfyUI download paths.
CHUNK_SIZE = 1 << 16
# No overall wall-clock deadline (multi-GB model files can legitimately take
# a long time). Instead, bound the TCP connect and the gap between socket
# reads, so a stalled or dead peer surfaces as an error instead of a hang.
REQUEST_TIMEOUT = aiohttp.ClientTimeout(
    sock_read=120,
    sock_connect=30,
    total=None,
)
async def stream_to_disk(session: DownloadSession) -> str:
    """Run a single download to completion or cancellation.

    The response body is streamed into ``<final_path>.tmp``. Once the whole
    body has been written and the session is still active, the ``.tmp`` is
    promoted over ``final_path`` with ``os.replace``.

    The session is finished (removed from the registry through
    ``DOWNLOAD_SERVER.finish``) exactly once, in this function's
    ``finally``. That includes the case where destination resolution or
    directory creation fails. Callers do not need to call
    ``DOWNLOAD_SERVER.finish`` themselves.

    Args:
        session: Registry entry naming the source (``session.url``) and the
            model whose destination is resolved (``session.model_id``).

    Returns:
        The final on-disk path of the downloaded file.

    Raises:
        DownloadCancelled: The session stopped being active while streaming
            (or just before the rename). The ``.tmp`` has been removed.
        asyncio.CancelledError: The task running this coroutine was
            cancelled. The ``.tmp`` has been removed.
        DownloadError: The server answered with a status other than 200.
        Exception: Any other failure (path resolution, network, disk) is
            logged, the ``.tmp`` is removed, and the error is re-raised.
    """
    # Bound lazily: resolve_destination itself may raise, and the cleanup
    # handlers below must not try to delete a path that was never computed.
    tmp_path: Optional[str] = None
    bytes_seen = 0
    try:
        final_path, tmp_path = resolve_destination(session.model_id)
        os.makedirs(os.path.dirname(final_path), exist_ok=True)
        # Wipe any stale .tmp from a previous failed attempt before we start.
        # Otherwise a partial body could masquerade as our completed download
        # when the rename finally happens.
        _remove_if_exists(tmp_path)

        http = await get_session()
        headers = _auth_headers_for(session.url)
        logging.info(
            "[model_downloader] starting GET %s (auth=%s)",
            session.url, "yes" if "Authorization" in headers else "no",
        )
        async with http.get(
            session.url,
            allow_redirects=True,
            timeout=REQUEST_TIMEOUT,
            headers=headers,
        ) as resp:
            if resp.status != 200:
                # Capture a snippet of the response body so 4xx/5xx aren't
                # opaque in the logs. HF returns JSON or HTML with a
                # human-readable reason on failures.
                body_snippet = await _read_short(resp)
                logging.warning(
                    "[model_downloader] GET %s failed: status=%d final_url=%s body=%s",
                    session.url, resp.status, str(resp.url), body_snippet,
                )
                raise DownloadError(
                    f"unexpected HTTP {resp.status} fetching {session.url}: {body_snippet}",
                    status=resp.status,
                )
            total = parse_content_length(resp.headers.get("Content-Length"))
            DOWNLOAD_SERVER.update_progress(session, 0, total)
            with open(tmp_path, "wb") as f:
                async for chunk in resp.content.iter_chunked(CHUNK_SIZE):
                    # Cancellation check between chunks. This is cheap and
                    # bounds cancellation latency to one chunk plus one
                    # write(). The write itself is a synchronous call.
                    if not DOWNLOAD_SERVER.is_active(session):
                        raise DownloadCancelled()
                    f.write(chunk)
                    bytes_seen += len(chunk)
                    DOWNLOAD_SERVER.update_progress(session, bytes_seen, total)
        # Final cancellation check before the .tmp is promoted to the real
        # filename. This covers a cancel that arrives during the very last
        # chunk, which we would otherwise commit anyway.
        if not DOWNLOAD_SERVER.is_active(session):
            raise DownloadCancelled()
        # The file handle is closed at this point, which os.replace needs on
        # some platforms. The rename is atomic because tmp_path lives next to
        # final_path on the same filesystem.
        os.replace(tmp_path, final_path)
        logging.info(
            "[model_downloader] downloaded %s (%d bytes) from %s",
            session.model_id, bytes_seen, session.url,
        )
        return final_path
    except DownloadCancelled:
        logging.info("[model_downloader] cancelled: %s", session.model_id)
        _discard_tmp(tmp_path)
        raise
    except asyncio.CancelledError:
        # Task-level cancellation (e.g. shutdown). CancelledError derives from
        # BaseException, so the generic handler below would not see it and
        # the partial .tmp would be left behind.
        logging.info("[model_downloader] task cancelled: %s", session.model_id)
        _discard_tmp(tmp_path)
        raise
    except Exception as e:
        logging.warning(
            "[model_downloader] failed: %s from %s: %s: %s",
            session.model_id, session.url, type(e).__name__, e,
            exc_info=True,
        )
        _discard_tmp(tmp_path)
        raise
    finally:
        # In all terminal states (success / cancel / error) drop the
        # session from the registry. Idempotent: it only removes the entry
        # if we're still the live epoch for this model_id.
        DOWNLOAD_SERVER.finish(session)


def _discard_tmp(tmp_path: Optional[str]) -> None:
    """Remove a partial download, but only if its path was ever resolved."""
    if tmp_path is not None:
        _remove_if_exists(tmp_path)
class DownloadError(Exception):
    """Raised when a download fails at the network or protocol level.

    Attributes:
        status: HTTP status code of the offending response, or ``None``
            when the failure is not tied to a specific status.
    """

    def __init__(self, message: str, status: Optional[int] = None) -> None:
        self.status = status
        Exception.__init__(self, message)
async def _read_short(resp: aiohttp.ClientResponse, limit: int = 512) -> str:
"""Read up to ``limit`` bytes of a response body for logging.
Used to surface the JSON/HTML reason from an HF non-2xx response in
server logs instead of just the status code. Best-effort: any
error here is swallowed.
"""
try:
raw = await resp.content.read(limit)
return raw.decode("utf-8", errors="replace").strip()
except Exception:
return "<unreadable>"
def _auth_headers_for(url: str) -> dict[str, str]:
    """Build the extra request headers to send with the GET for ``url``.

    Only HuggingFace URLs (as judged by ``is_hf_url``) get credentials. When
    the auth store holds a token with a non-empty access token, it is sent
    as ``Authorization: Bearer <token>``, which is how HF authorizes access
    to gated repos. Every other host, and HF without a usable token, gets an
    empty dict. This ensures the token is never forwarded anywhere else.
    """
    if not is_hf_url(url):
        return {}
    stored = HF_AUTH_STORE.get_token_sync()
    bearer = None if stored is None else stored.access_token
    if not bearer:
        return {}
    return {"Authorization": f"Bearer {bearer}"}
def _remove_if_exists(path: str) -> None:
try:
os.remove(path)
except FileNotFoundError:
pass
except OSError as e:
logging.warning("[model_downloader] could not remove %s: %s", path, e)
async def run_batch_sequential(sessions: list[DownloadSession]) -> None:
    """Run a list of sessions one after the other.

    Each session is independent: a failure or cancellation of one does not
    abort the rest. Cancellations are observable via the registry *before* a
    given download starts. A session cancelled before the worker reached it
    is therefore finished and skipped without touching disk.

    If the task running this batch is itself cancelled
    (``asyncio.CancelledError``), every session queued behind the one in
    progress is finished as well. This means none of them stays registered
    as active with no worker left to drive it. The cancellation is then
    re-raised.

    Args:
        sessions: Sessions to download, in order.
    """
    for position, session in enumerate(sessions):
        # If the session got cancelled before its turn, skip without
        # touching disk. This is what makes the per-request "sequential
        # but cancellable" semantic work.
        if not DOWNLOAD_SERVER.is_active(session):
            DOWNLOAD_SERVER.finish(session)
            continue
        try:
            await stream_to_disk(session)
        except DownloadCancelled:
            # Already logged + tmp removed inside stream_to_disk.
            continue
        except asyncio.CancelledError:
            # stream_to_disk finishes the current session in its own
            # `finally`; release the sessions that were still queued.
            for pending in sessions[position + 1:]:
                DOWNLOAD_SERVER.finish(pending)
            raise
        except Exception:
            # stream_to_disk already logged. Continue with the rest of the batch.
            continue
# Strong references to in-flight batch tasks. The event loop keeps only weak
# references to tasks, so a fire-and-forget task whose handle the caller
# drops could otherwise be garbage-collected before it finishes.
_BACKGROUND_BATCHES: set[asyncio.Task] = set()


def schedule_batch(sessions: list[DownloadSession]) -> asyncio.Task:
    """Kick off ``run_batch_sequential`` on the running event loop.

    The API handler may return immediately after scheduling and ignore the
    returned task. Clients observe progress via the polling endpoints. The
    task is kept alive by this module until it completes, regardless of
    whether the caller holds on to it.

    Args:
        sessions: Sessions to download, in order.

    Returns:
        The scheduled task (also tracked internally until done).
    """
    task = asyncio.create_task(run_batch_sequential(sessions))
    _BACKGROUND_BATCHES.add(task)
    # Drop our reference once the batch has finished (for any reason).
    task.add_done_callback(_BACKGROUND_BATCHES.discard)
    return task