From fdd84d04a0e3006134c649f23935745ce4082ed3 Mon Sep 17 00:00:00 2001 From: DoronGenzelHass Date: Mon, 22 Jun 2026 12:02:46 +0300 Subject: [PATCH] feat(model_downloader): server-side model download + HuggingFace OAuth subsystem Self-contained package under app/model_downloader/: - Allowlist + path-validated downloads (SSRF guard: HF/Civitai/localhost + model extension). - Streaming worker: writes to .tmp, atomic rename on success, cooperative cancellation with epoch-based session identity, orphan .tmp sweep. - Unified availability probe with per-URL gated/size caching; is_hf_downloadable recomputed per call so login/license changes surface within one poll. - HuggingFace OAuth 2.0 PKCE flow with loopback callback server and on-disk (0600) token storage + transparent refresh. - Pydantic request/response schemas and aiohttp routes under api/. Co-Authored-By: Claude Opus 4.8 (1M context) --- app/model_downloader/allowlist.py | 46 +++ app/model_downloader/api/routes.py | 332 ++++++++++++++++++++ app/model_downloader/api/schemas_in.py | 41 +++ app/model_downloader/api/schemas_out.py | 81 +++++ app/model_downloader/download_server.py | 179 +++++++++++ app/model_downloader/downloader.py | 205 ++++++++++++ app/model_downloader/gated_detection.py | 238 ++++++++++++++ app/model_downloader/hf_auth/auth_store.py | 106 +++++++ app/model_downloader/hf_auth/eligibility.py | 55 ++++ app/model_downloader/hf_auth/oauth.py | 277 ++++++++++++++++ app/model_downloader/hf_auth/token_store.py | 89 ++++++ app/model_downloader/hf_url.py | 41 +++ app/model_downloader/http_client.py | 63 ++++ app/model_downloader/paths.py | 93 ++++++ 14 files changed, 1846 insertions(+) create mode 100644 app/model_downloader/allowlist.py create mode 100644 app/model_downloader/api/routes.py create mode 100644 app/model_downloader/api/schemas_in.py create mode 100644 app/model_downloader/api/schemas_out.py create mode 100644 app/model_downloader/download_server.py create mode 100644 app/model_downloader/downloader.py create mode 100644 app/model_downloader/gated_detection.py create mode 100644 app/model_downloader/hf_auth/auth_store.py create mode 100644 app/model_downloader/hf_auth/eligibility.py create mode 100644 app/model_downloader/hf_auth/oauth.py create mode 100644 app/model_downloader/hf_auth/token_store.py create mode 100644 app/model_downloader/hf_url.py create mode 100644 app/model_downloader/http_client.py create mode 100644 app/model_downloader/paths.py diff --git a/app/model_downloader/allowlist.py b/app/model_downloader/allowlist.py new file mode 100644 index 000000000..ea5accc76 --- /dev/null +++ b/app/model_downloader/allowlist.py @@ -0,0 +1,46 @@ +"""URL allowlist for server-side model fetches. + +Mirrors the frontend's ``isModelDownloadable`` allowlist so the two flows +agree on which URLs are eligible for download. Server-side allowlisting is +the primary SSRF defense for this subsystem — workflow JSON is untrusted +input (anyone can hand-craft one), so we never let the server fetch URLs +outside this list. +""" + +from urllib.parse import urlparse + +# Frontend parity: ``missingModelDownload-*.js`` exports the same triple as +# ``i = [...]`` (Civitai / HuggingFace / localhost). +_ALLOWED_URL_PREFIXES = ( + "https://huggingface.co/", + "https://civitai.com/", + "http://localhost:", + "http://127.0.0.1:", +) + +# Frontend parity: same set as ``a = [...]`` in the bundle. +_ALLOWED_MODEL_EXTENSIONS = ( + ".safetensors", + ".sft", + ".ckpt", + ".pth", + ".pt", +) + + +def is_url_allowed(url: str) -> bool: + """Check whether ``url`` is permitted as a server-side download source. + + Returns True only when both: + - the URL starts with one of the allowed prefixes, AND + - the URL's final path segment ends with a known model extension. + + Both checks are required to keep arbitrary HTML / API endpoints on + allowlisted hosts (e.g. ``https://huggingface.co/api/...``) off the table. + """ + if not isinstance(url, str) or not url: + return False + if not any(url.startswith(p) for p in _ALLOWED_URL_PREFIXES): + return False + path = urlparse(url).path + return any(path.endswith(ext) for ext in _ALLOWED_MODEL_EXTENSIONS) diff --git a/app/model_downloader/api/routes.py b/app/model_downloader/api/routes.py new file mode 100644 index 000000000..921065540 --- /dev/null +++ b/app/model_downloader/api/routes.py @@ -0,0 +1,332 @@ +"""Aiohttp routes for the server-side model download subsystem. + +Endpoint surface (all under ``/api/``, all kebab-case): + + - ``POST /api/models-availability-status`` — bulk status + metadata query. + - ``POST /api/download-models`` — start a batch of downloads. + - ``POST /api/cancel-model-download-session`` — cancel a single in-flight one. + - ``GET /api/hf-auth-token-status`` — current HF login state. + - ``POST /api/hf-auth-login-start`` — begin the HF OAuth flow. + - ``POST /api/hf-auth-logout`` — drop the stored HF token. + +The contract is intentionally narrow: only model_ids of the form +``/`` (validated via ``app.model_downloader.paths``) +are accepted, and only URLs on the same allowlist the frontend already +uses (HuggingFace, Civitai, localhost) can be fetched. Both are required +to keep the server out of the SSRF business for this feature. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from typing import Any, Optional + +from aiohttp import web +from pydantic import BaseModel, ValidationError + +from app.model_downloader.allowlist import is_url_allowed +from app.model_downloader.download_server import ( + DOWNLOAD_SERVER, + DownloadSession, +) +from app.model_downloader.downloader import schedule_batch +from app.model_downloader.gated_detection import probe_url +from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE +from app.model_downloader.hf_auth.eligibility import is_hf_auth_eligible +from app.model_downloader.hf_auth.oauth import ( + OAuthInProgressError, + start_login_flow, +) +from app.model_downloader.paths import ( + InvalidModelId, + parse_model_id, + resolve_existing, +) +from app.model_downloader.api import schemas_in, schemas_out + +ROUTES = web.RouteTableDef() + + +def register_routes(app: web.Application) -> None: + """Wire the model-downloader routes into the running aiohttp app. + + Called once from ``server.py`` during ``PromptServer`` startup. + """ + app.add_routes(ROUTES) + + +# ----- response helpers (same envelope as app/assets/api/routes.py) ----- + + +def _error(status: int, code: str, message: str, details: dict | None = None) -> web.Response: + return web.json_response( + {"error": {"code": code, "message": message, "details": details or {}}}, + status=status, + ) + + +def _validation_error(code: str, ve: ValidationError) -> web.Response: + return _error(400, code, "Validation failed.", {"errors": json.loads(ve.json())}) + + +def _ok(payload: BaseModel, status: int = 200) -> web.Response: + return web.json_response( + payload.model_dump(mode="json", exclude_none=False), + status=status, + ) + + +async def _parse_body(request: web.Request, model: type[BaseModel]) -> Any: + """Parse a JSON body into a pydantic model or raise a 400 response.""" + try: + raw = await request.json() + except json.JSONDecodeError: + return _error(400, "INVALID_JSON", "Request body must be valid JSON.") + try: + return model.model_validate(raw) + except ValidationError as ve: + return _validation_error("INVALID_BODY", ve) + + +# ----- 1. availability status (unified: state + metadata per id) ----- + + +@ROUTES.post("/api/models-availability-status") +async def models_availability_status(request: web.Request) -> web.Response: + """Return per-id ``{state, progress, file_size, is_hf_downloadable}``. + + State (``available`` / ``missing`` / ``downloading``) is cheap to + recompute per call. ``file_size`` and ``is_gated`` are cached + server-side per URL. ``is_hf_downloadable`` is recomputed every + call from the current token state — that's what makes login + license + acceptance show up in the UI within one poll cycle without any + frontend cache plumbing. + """ + parsed = await _parse_body(request, schemas_in.AvailabilityStatusRequest) + if isinstance(parsed, web.Response): + return parsed + + items = list(parsed.models.items()) + + # Run all probes concurrently; each is internally cached per URL. + probes = await asyncio.gather(*(probe_url(url) for _, url in items)) + + response_models: dict[str, schemas_out.ModelStatusEntry] = {} + for (model_id, _url), probe in zip(items, probes): + try: + parse_model_id(model_id) + except InvalidModelId: + # Ill-formed identifier: report as missing without 400-ing the + # whole batch — the workflow author probably typo'd. + response_models[model_id] = schemas_out.ModelStatusEntry( + state="missing", + file_size=probe.file_size, + is_hf_downloadable=probe.is_hf_downloadable, + ) + continue + + active = DOWNLOAD_SERVER.get(model_id) + if active is not None: + response_models[model_id] = schemas_out.ModelStatusEntry( + state="downloading", + progress=schemas_out.DownloadProgress( + bytes_downloaded=active.bytes_downloaded, + total_bytes=active.total_bytes, + progress=active.progress, + ), + file_size=probe.file_size, + is_hf_downloadable=probe.is_hf_downloadable, + ) + continue + + state: schemas_out.ModelState = ( + "available" if resolve_existing(model_id) is not None else "missing" + ) + response_models[model_id] = schemas_out.ModelStatusEntry( + state=state, + file_size=probe.file_size, + is_hf_downloadable=probe.is_hf_downloadable, + ) + + return _ok(schemas_out.AvailabilityStatusResponse( + models=response_models, + hf_auth=schemas_out.HfAuthStatus( + token_available=HF_AUTH_STORE.has_token(), + eligible=is_hf_auth_eligible(), + ), + )) + + +# ----- 2. start downloads ----- + + +@ROUTES.post("/api/download-models") +async def download_models(request: web.Request) -> web.Response: + parsed = await _parse_body(request, schemas_in.DownloadModelsRequest) + if isinstance(parsed, web.Response): + return parsed + + if not parsed.models: + return _error(400, "EMPTY_REQUEST", "No models supplied.") + + # ----- precondition pass: validate everything BEFORE registering anything ----- + # Atomic semantics: if any model fails any precondition (invalid id, + # not allow-listed URL, already on disk, already downloading, or gated), + # the entire request fails and no state is changed. + requested = list(parsed.models.items()) + + for model_id, url in requested: + try: + parse_model_id(model_id) + except InvalidModelId as e: + return _error(400, "INVALID_MODEL_ID", str(e), + {"model_id": model_id}) + + if not is_url_allowed(url): + return _error( + 400, "URL_NOT_ALLOWED", + "Server-side downloads only accept HuggingFace, Civitai, " + "or localhost URLs ending in a known model extension.", + {"model_id": model_id, "url": url}, + ) + + if resolve_existing(model_id) is not None: + return _error(409, "ALREADY_AVAILABLE", + f"Model already exists on disk: {model_id}", + {"model_id": model_id}) + + if DOWNLOAD_SERVER.is_downloading(model_id): + return _error(409, "ALREADY_DOWNLOADING", + f"A download for {model_id} is already in progress.", + {"model_id": model_id}) + + # Reachability check last — it's the only one that talks to the + # network. Concurrent probes. For HF URLs ``is_hf_downloadable`` + # reflects current token access; for non-HF URLs it's None, and we + # treat that as "no info, proceed". + probes = await asyncio.gather(*(probe_url(url) for _, url in requested)) + for (model_id, url), probe in zip(requested, probes): + if probe.is_hf_downloadable is False: + return _error( + 400, "MODEL_NOT_DOWNLOADABLE", + f"Model {model_id} is gated on HuggingFace and the current " + f"server token (if any) does not grant access.", + {"model_id": model_id, "url": url}, + ) + + # ----- registration pass: try_register is atomic per model_id ----- + # Defensive: another request might have raced past our pre-check + # between the loop above and here. try_register handles that. + sessions: list[DownloadSession] = [] + for model_id, url in requested: + session = DOWNLOAD_SERVER.try_register(model_id, url) + if session is None: + # Race: someone else got in. Roll back what we registered. + for s in sessions: + DOWNLOAD_SERVER.cancel(s.model_id) + return _error(409, "ALREADY_DOWNLOADING", + f"A download for {model_id} is already in progress (race).", + {"model_id": model_id}) + sessions.append(session) + + DOWNLOAD_SERVER.sweep_orphan_tmp_files() + schedule_batch(sessions) + logging.info( + "[model_downloader] scheduled %d downloads: %s", + len(sessions), [s.model_id for s in sessions], + ) + + return _ok(schemas_out.DownloadModelsResponse( + accepted=True, + scheduled=[s.model_id for s in sessions], + ), status=202) + + +# ----- 3. cancel a session ----- + + +@ROUTES.post("/api/cancel-model-download-session") +async def cancel_model_download_session(request: web.Request) -> web.Response: + parsed = await _parse_body(request, schemas_in.CancelDownloadSessionRequest) + if isinstance(parsed, web.Response): + return parsed + + cancelled = DOWNLOAD_SERVER.cancel(parsed.model_id) + if not cancelled: + return _error(404, "NOT_DOWNLOADING", + f"No active download for {parsed.model_id}.", + {"model_id": parsed.model_id}) + + return _ok(schemas_out.CancelDownloadSessionResponse(cancelled=True)) + + +# ----- 4. HuggingFace OAuth status / login start / logout ----- + + +@ROUTES.get("/api/hf-auth-token-status") +async def hf_auth_token_status(request: web.Request) -> web.Response: + """Return whether the server holds a usable HF token + its username. + + Used by the settings UI and (out-of-band) by the frontend on + login completion. ``token_available`` is true even if the cached + access_token is expired — as long as a refresh_token exists, the + user is "logged in" from their perspective. + """ + token_present = HF_AUTH_STORE.has_token() + username: Optional[str] = None + if token_present: + # Resolve the username via whoami. Done in a worker thread because + # huggingface_hub's whoami is synchronous + blocks on a network call. + tok = await HF_AUTH_STORE.get_valid_token() + if tok is not None: + try: + username = await asyncio.to_thread(_whoami_username, tok.access_token) + except Exception as e: + logging.debug("[hf_auth] whoami failed: %s", e) + return _ok(schemas_out.HfAuthTokenStatusResponse( + token_available=token_present, + username=username, + )) + + +def _whoami_username(token: str) -> Optional[str]: + """Sync helper: ask HF for the user name attached to a token.""" + from huggingface_hub import HfApi + info = HfApi().whoami(token=token) + if isinstance(info, dict): + return info.get("name") or info.get("fullname") + return None + + +@ROUTES.post("/api/hf-auth-login-start") +async def hf_auth_login_start(request: web.Request) -> web.Response: + """Begin one OAuth attempt: bind the callback port, return the URL. + + Rejected outright if this deployment isn't eligible (we don't + surface the option on multi-tenant / public-IP installs). + """ + if not is_hf_auth_eligible(): + return _error( + 403, "HF_AUTH_NOT_ELIGIBLE", + "This server is not eligible for interactive HuggingFace login. " + "It must be bound to a loopback address and not running in " + "--multi-user mode.", + ) + try: + url = await start_login_flow() + except OAuthInProgressError: + return _error( + 409, "HF_AUTH_IN_PROGRESS", + "Another HuggingFace login attempt is in progress. Try again " + "after it completes or times out.", + ) + return _ok(schemas_out.HfAuthLoginStartResponse(authorize_url=url)) + + +@ROUTES.post("/api/hf-auth-logout") +async def hf_auth_logout(request: web.Request) -> web.Response: + """Drop the in-memory + on-disk HF token.""" + HF_AUTH_STORE.clear() + return _ok(schemas_out.HfAuthLogoutResponse(logged_out=True)) diff --git a/app/model_downloader/api/schemas_in.py b/app/model_downloader/api/schemas_in.py new file mode 100644 index 000000000..d792520ea --- /dev/null +++ b/app/model_downloader/api/schemas_in.py @@ -0,0 +1,41 @@ +"""Request schemas for the model-downloader API. + +Each endpoint accepts a small JSON body. Pydantic enforces the shape at +the boundary; route handlers operate only on validated values past that. +""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class AvailabilityStatusRequest(BaseModel): + """``POST /api/models-availability-status``. + + Sent by the frontend on each poll. Each entry is ``{model_id: url}``; + the URL is the one declared in ``properties.models[i].url`` in the + workflow JSON and lets the server compute per-id metadata + (``file_size`` + ``is_hf_downloadable``) on the same request. + """ + models: dict[str, str] = Field(default_factory=dict) + + +class DownloadModelsRequest(BaseModel): + """``POST /api/download-models``. + + Same shape as the metadata request — the URL for each model_id. + Returns immediately after validation and scheduling. + """ + models: dict[str, str] = Field(default_factory=dict) + + +class CancelDownloadSessionRequest(BaseModel): + """``POST /api/cancel-model-download-session``.""" + model_id: str + + +__all__ = [ + "AvailabilityStatusRequest", + "DownloadModelsRequest", + "CancelDownloadSessionRequest", +] diff --git a/app/model_downloader/api/schemas_out.py b/app/model_downloader/api/schemas_out.py new file mode 100644 index 000000000..5571ecd28 --- /dev/null +++ b/app/model_downloader/api/schemas_out.py @@ -0,0 +1,81 @@ +"""Response schemas for the model-downloader API.""" + +from __future__ import annotations + +from typing import Literal, Optional + +from pydantic import BaseModel + + +ModelState = Literal["available", "missing", "downloading"] + + +class DownloadProgress(BaseModel): + """Embedded in a model entry when its state is ``downloading``.""" + bytes_downloaded: int + total_bytes: Optional[int] = None + progress: Optional[float] = None # fraction in [0,1]; null until total known + + +class ModelStatusEntry(BaseModel): + """Everything the UI needs to render one row, in one shot. + + ``state`` reflects what the server has on disk + in-flight; ``file_size`` + and ``is_hf_downloadable`` come from probes (intrinsic; cached). + The HF fields are populated for every poll (cached on the server), + so license-acceptance flips show up within one poll interval without + any frontend cache invalidation. + """ + state: ModelState + progress: Optional[DownloadProgress] = None + file_size: Optional[int] = None + # HF-only: True iff the server can fetch this URL with current auth + # state. False iff gated and lacking access. None for non-HF URLs. + is_hf_downloadable: Optional[bool] = None + + +class HfAuthStatus(BaseModel): + """Snapshot of HF login state, embedded in availability response.""" + token_available: bool + eligible: bool + + +class AvailabilityStatusResponse(BaseModel): + models: dict[str, ModelStatusEntry] + hf_auth: HfAuthStatus + + +class DownloadModelsResponse(BaseModel): + accepted: bool + scheduled: list[str] + + +class CancelDownloadSessionResponse(BaseModel): + cancelled: bool + + +class HfAuthTokenStatusResponse(BaseModel): + token_available: bool + username: Optional[str] = None + + +class HfAuthLoginStartResponse(BaseModel): + authorize_url: str + + +class HfAuthLogoutResponse(BaseModel): + logged_out: bool + + +__all__ = [ + "ModelState", + "DownloadProgress", + "ModelStatusEntry", + "HfAuthStatus", + "AvailabilityStatusResponse", + "DownloadModelsResponse", + "CancelDownloadSessionResponse", + "HfAuthTokenStatusResponse", + "HfAuthLoginStartResponse", + "HfAuthLogoutResponse", +] diff --git a/app/model_downloader/download_server.py b/app/model_downloader/download_server.py new file mode 100644 index 000000000..c323309d8 --- /dev/null +++ b/app/model_downloader/download_server.py @@ -0,0 +1,179 @@ +"""Process-wide registry of in-flight model downloads. + +A single instance, ``DOWNLOAD_SERVER``, tracks every currently-running +server-side model fetch. Designed to be safe with multiple concurrent +clients hitting the API: each model_id has at most one active session, +and the API rejects requests that conflict with in-flight downloads. + +Cancellation is cooperative — the download loop checks ``is_active`` on +its own session between chunks and raises ``DownloadCancelled`` when the +session has been removed from the registry. This avoids the complications +of ``Task.cancel()`` from outside the loop while still giving deterministic +rollback semantics (the worker is responsible for deleting its own +``.tmp`` on the cancel path). +""" + +from __future__ import annotations + +import logging +import os +import threading +from dataclasses import dataclass, field +from typing import Optional + +from app.model_downloader.paths import iter_all_tmp_paths + + +class DownloadCancelled(Exception): + """Raised by the streaming loop when its session has been removed + from the registry (cancellation request) and the worker should roll + back its ``.tmp`` file.""" + + +@dataclass +class DownloadSession: + """One in-flight download. + + ``progress`` is a fraction in ``[0.0, 1.0]``; ``None`` until the first + byte arrives and we know whether the response carries a + ``Content-Length``. ``total_bytes`` mirrors that header when present. + """ + model_id: str + url: str + progress: Optional[float] = None + bytes_downloaded: int = 0 + total_bytes: Optional[int] = None + # Sequence number used solely as identity for the cancellation check — + # so that "cancel + restart" doesn't get confused by stale workers. + epoch: int = field(default_factory=lambda: 0) + + +class DownloadServer: + """Singleton registry of active downloads. + + All mutation goes through this object so concurrent route handlers + see a consistent view. The ``_lock`` is a plain threading lock + because the registry is consulted from both the asyncio event-loop + thread (route handlers) and from any worker coroutines spawned to + perform downloads. + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._sessions: dict[str, DownloadSession] = {} + self._epoch_counter = 0 + self._orphan_sweep_done = False + + # ----- lifecycle ----- + + def sweep_orphan_tmp_files(self) -> None: + """Idempotently sweep ``*.tmp`` files left by crashed downloads. + + Deferred off the import path so module load doesn't block on + filesystem I/O against potentially-slow mounts. Each route handler + that might create a new ``.tmp`` runs this exactly once. + """ + with self._lock: + if self._orphan_sweep_done: + return + self._orphan_sweep_done = True + for path in iter_all_tmp_paths(): + try: + os.remove(path) + logging.info("[model_downloader] removed orphan tmp file: %s", path) + except OSError as e: + logging.warning("[model_downloader] could not remove %s: %s", path, e) + + # ----- queries ----- + + def is_downloading(self, model_id: str) -> bool: + with self._lock: + return model_id in self._sessions + + def get(self, model_id: str) -> Optional[DownloadSession]: + with self._lock: + return self._sessions.get(model_id) + + def snapshot(self) -> dict[str, DownloadSession]: + """Return a shallow copy of the current sessions map.""" + with self._lock: + return dict(self._sessions) + + # ----- mutations ----- + + def try_register(self, model_id: str, url: str) -> Optional[DownloadSession]: + """Atomically register a new session iff none exists for ``model_id``. + + Returns the new session on success, ``None`` if a session is already + in flight. Callers must check the return value — the caller is the + sole owner of the session it gets back. + """ + with self._lock: + if model_id in self._sessions: + return None + self._epoch_counter += 1 + session = DownloadSession( + model_id=model_id, + url=url, + epoch=self._epoch_counter, + ) + self._sessions[model_id] = session + return session + + def update_progress( + self, + session: DownloadSession, + bytes_downloaded: int, + total_bytes: Optional[int], + ) -> None: + """Update progress on a session. No-op if the session has been + removed (cancelled) — caller should check ``is_active`` separately.""" + with self._lock: + current = self._sessions.get(session.model_id) + if current is None or current.epoch != session.epoch: + return + current.bytes_downloaded = bytes_downloaded + current.total_bytes = total_bytes + if total_bytes and total_bytes > 0: + current.progress = min(1.0, bytes_downloaded / total_bytes) + + def is_active(self, session: DownloadSession) -> bool: + """True iff this exact session is still the registered one for + its model_id. False after cancellation, after completion, or if + another session has replaced it.""" + with self._lock: + current = self._sessions.get(session.model_id) + return current is not None and current.epoch == session.epoch + + def finish(self, session: DownloadSession) -> None: + """Remove a completed (or cancelled) session from the registry. + + Safe to call multiple times. Only removes if the epoch matches — + we never accidentally evict a *newer* session for the same model_id. + """ + with self._lock: + current = self._sessions.get(session.model_id) + if current is not None and current.epoch == session.epoch: + del self._sessions[session.model_id] + + def reset_for_tests(self) -> None: + """Clear all sessions and reset the epoch counter. Test-only.""" + with self._lock: + self._sessions.clear() + self._epoch_counter = 0 + + def cancel(self, model_id: str) -> bool: + """Remove the session registered for ``model_id``. + + Returns True if there was an active session to cancel. The worker + will discover the cancellation on its next ``is_active`` check + and roll back its ``.tmp`` file. + """ + with self._lock: + if model_id in self._sessions: + del self._sessions[model_id] + return True + return False + + +DOWNLOAD_SERVER = DownloadServer() diff --git a/app/model_downloader/downloader.py b/app/model_downloader/downloader.py new file mode 100644 index 000000000..091a4a398 --- /dev/null +++ b/app/model_downloader/downloader.py @@ -0,0 +1,205 @@ +"""Streaming download worker with progress reporting and cancellation. + +Each download writes to ``.tmp`` and atomically renames into +place on success. Between chunks the worker checks the registry for +cancellation (via ``DownloadServer.is_active``) and rolls back its +``.tmp`` on cancel or on any error. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +from typing import Optional + +import aiohttp + +from app.model_downloader.download_server import ( + DOWNLOAD_SERVER, + DownloadCancelled, + DownloadSession, +) +from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE +from app.model_downloader.hf_url import is_hf_url +from app.model_downloader.http_client import get_session, parse_content_length +from app.model_downloader.paths import resolve_destination + + +CHUNK_SIZE = 64 * 1024 # 64 KiB — same scale as other ComfyUI download paths. +REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=120) + + +async def stream_to_disk(session: DownloadSession) -> str: + """Run a single download to completion or cancellation. + + Returns the final on-disk path on success. Removes the ``.tmp`` and + raises on cancellation or failure. The session is finished + (removed from the registry) exactly once, here — callers do not + need to call ``DOWNLOAD_SERVER.finish`` themselves. + """ + final_path, tmp_path = resolve_destination(session.model_id) + os.makedirs(os.path.dirname(final_path), exist_ok=True) + + # Wipe any stale .tmp from a previous failed attempt before we start — + # otherwise a partial body could masquerade as our completed download + # when the rename finally happens. + _remove_if_exists(tmp_path) + + bytes_seen = 0 + try: + http = await get_session() + headers = _auth_headers_for(session.url) + logging.info( + "[model_downloader] starting GET %s (auth=%s)", + session.url, "yes" if "Authorization" in headers else "no", + ) + async with http.get( + session.url, + allow_redirects=True, + timeout=REQUEST_TIMEOUT, + headers=headers, + ) as resp: + if resp.status != 200: + # Capture a snippet of the response body so 4xx/5xx aren't + # opaque in the logs — HF returns JSON or HTML with a + # human-readable reason on failures. + body_snippet = await _read_short(resp) + logging.warning( + "[model_downloader] GET %s failed: status=%d final_url=%s body=%s", + session.url, resp.status, str(resp.url), body_snippet, + ) + raise DownloadError( + f"unexpected HTTP {resp.status} fetching {session.url}: {body_snippet}", + status=resp.status, + ) + + total = parse_content_length(resp.headers.get("Content-Length")) + DOWNLOAD_SERVER.update_progress(session, 0, total) + + with open(tmp_path, "wb") as f: + async for chunk in resp.content.iter_chunked(CHUNK_SIZE): + # Cancellation check between chunks. Cheap and means + # cancellation latency is bounded by one chunk plus + # one ``write()`` — typically well under a second + # even on slow disks. + if not DOWNLOAD_SERVER.is_active(session): + raise DownloadCancelled() + f.write(chunk) + bytes_seen += len(chunk) + DOWNLOAD_SERVER.update_progress(session, bytes_seen, total) + + # Final cancellation check before we promote the .tmp to the real + # filename — avoids the awkward case where cancel arrives during + # the very last chunk and we'd otherwise commit anyway. + if not DOWNLOAD_SERVER.is_active(session): + raise DownloadCancelled() + + # Atomic rename. os.replace is atomic within the same filesystem, + # which is guaranteed here because tmp lives alongside final_path. + os.replace(tmp_path, final_path) + logging.info( + "[model_downloader] downloaded %s (%d bytes) from %s", + session.model_id, bytes_seen, session.url, + ) + return final_path + + except DownloadCancelled: + logging.info("[model_downloader] cancelled: %s", session.model_id) + _remove_if_exists(tmp_path) + raise + except Exception as e: + logging.warning( + "[model_downloader] failed: %s from %s: %s: %s", + session.model_id, session.url, type(e).__name__, e, + exc_info=True, + ) + _remove_if_exists(tmp_path) + raise + finally: + # In all terminal states (success / cancel / error) drop the + # session from the registry. Idempotent — only removes if we're + # still the live epoch for this model_id. + DOWNLOAD_SERVER.finish(session) + + +class DownloadError(Exception): + """Network / protocol error during a download.""" + + def __init__(self, message: str, status: Optional[int] = None) -> None: + super().__init__(message) + self.status = status + + +async def _read_short(resp: aiohttp.ClientResponse, limit: int = 512) -> str: + """Read up to ``limit`` bytes of a response body for logging. + + Used to surface the JSON/HTML reason from an HF non-2xx response in + server logs instead of just the status code. Best-effort: any + error here is swallowed. + """ + try: + raw = await resp.content.read(limit) + return raw.decode("utf-8", errors="replace").strip() + except Exception: + return "" + + +def _auth_headers_for(url: str) -> dict[str, str]: + """Return any auth headers we should add to the GET for ``url``. + + For HuggingFace URLs we inject the user's OAuth access token as a + Bearer header — this is HF's documented way to access gated repos + (see ``huggingface_hub.hf_hub_download``'s wire format). For every + other host we send no extra headers; allowlisted public files + don't need them and we don't want to leak tokens to other hosts. + """ + if not is_hf_url(url): + return {} + tok = HF_AUTH_STORE.get_token_sync() + if tok is None or not tok.access_token: + return {} + return {"Authorization": f"Bearer {tok.access_token}"} + + +def _remove_if_exists(path: str) -> None: + try: + os.remove(path) + except FileNotFoundError: + pass + except OSError as e: + logging.warning("[model_downloader] could not remove %s: %s", path, e) + + +async def run_batch_sequential(sessions: list[DownloadSession]) -> None: + """Run a list of sessions one after the other. + + Each session is independent: a failure or cancellation of one does + not abort the rest. Cancellations are observable via the registry + *before* a given download starts, so a session that's been + pre-cancelled (cancel before the worker reached it) just gets skipped. + """ + for session in sessions: + # If the session got cancelled before its turn, skip without + # touching disk. This is what makes the per-request "sequential + # but cancellable" semantic work. + if not DOWNLOAD_SERVER.is_active(session): + DOWNLOAD_SERVER.finish(session) + continue + try: + await stream_to_disk(session) + except DownloadCancelled: + # Already logged + tmp removed inside stream_to_disk. + continue + except Exception: + # stream_to_disk already logged. Continue with the rest of the batch. + continue + + +def schedule_batch(sessions: list[DownloadSession]) -> asyncio.Task: + """Kick off ``run_batch_sequential`` on the running event loop. + + Returned task is fire-and-forget; the API handler returns immediately + after scheduling and clients observe progress via the polling endpoints. + """ + return asyncio.create_task(run_batch_sequential(sessions)) diff --git a/app/model_downloader/gated_detection.py b/app/model_downloader/gated_detection.py new file mode 100644 index 000000000..1dc2d3545 --- /dev/null +++ b/app/model_downloader/gated_detection.py @@ -0,0 +1,238 @@ +"""Per-URL probes for the unified availability endpoint. + +Three cached/derived facts per URL: + + - ``is_gated`` intrinsic to the model; cached forever once known. + Determined by ``auth_check(repo_id, token=None)``: + ``GatedRepoError`` → True, success → False. + + - ``is_hf_downloadable`` depends on the *current* token; recomputed every + call. For non-gated URLs this is trivially True + (no HF call needed). For gated URLs we run + ``auth_check`` with the stored token each call. + + - ``file_size`` intrinsic to the file. Cached forever once + determined (including ``None`` on transient + failure — we don't retry). We only attempt the + HEAD when we already know the URL is downloadable + to us; that way a failed-because-gated probe + never lands as a cached ``None``. + +Caches are per-process, in-memory; small, no eviction needed for the +workflow-scale (~tens of URLs). Concurrent calls for the same URL +deduplicate via per-URL ``asyncio.Lock``. +""" + +from __future__ import annotations + +import asyncio +import logging +from dataclasses import dataclass +from typing import Optional + +import aiohttp +from huggingface_hub import HfApi +from huggingface_hub.errors import ( + GatedRepoError, + HfHubHTTPError, + RepositoryNotFoundError, +) + +from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE +from app.model_downloader.hf_url import is_hf_url, repo_id_from_url +from app.model_downloader.http_client import get_session, parse_content_length + + +_HEAD_TIMEOUT = aiohttp.ClientTimeout(total=15) + + +@dataclass +class ProbeResult: + file_size: Optional[int] + is_hf_downloadable: Optional[bool] + + +# --- caches -------------------------------------------------------------- # + + +# url → bool. Whether this URL's HF repo gates access. Intrinsic to the +# model — never changes for a given URL. +_is_gated_cache: dict[str, bool] = {} + +# url → Optional[int]. The file's size in bytes, ``None`` if a probe +# was attempted and produced no answer. **Only populated when we knew +# the URL was downloadable to us at probe time** — so gated-without- +# access never lands a ``None`` here that we'd be stuck with after login. +_file_size_cache: dict[str, Optional[int]] = {} + +# Per-URL locks for single-flight probes — when multiple polls arrive +# in the same tick for the same URL, exactly one of them runs the HF +# call and the others wait on the result. +_locks: dict[str, asyncio.Lock] = {} + + +def _lock_for(url: str) -> asyncio.Lock: + lock = _locks.get(url) + if lock is None: + lock = asyncio.Lock() + _locks[url] = lock + return lock + + +def clear_caches_for_tests() -> None: + """Test-only: drop everything.""" + _is_gated_cache.clear() + _file_size_cache.clear() + _locks.clear() + + +# --- public entrypoint --------------------------------------------------- # + + +async def probe_url(url: str) -> ProbeResult: + """Return downloadability + size for one URL, using caches where safe.""" + if not is_hf_url(url): + # Non-HF: ``is_hf_downloadable`` is "not applicable" (None). + # Size we still cache so we don't HEAD on every poll. + size = await _get_or_probe_size(url, token=None) + return ProbeResult(file_size=size, is_hf_downloadable=None) + + repo_id = repo_id_from_url(url) + if repo_id is None: + return ProbeResult(file_size=None, is_hf_downloadable=None) + + # Determine intrinsic gating once. + gated = await _resolve_is_gated(url, repo_id) + if gated is None: + return ProbeResult(file_size=None, is_hf_downloadable=None) + + # Compute current-token downloadability per call. + tok = HF_AUTH_STORE.get_token_sync() + token_str: Optional[str] = tok.access_token if tok else None + if not gated: + is_hf_downloadable: Optional[bool] = True + else: + is_hf_downloadable = await _auth_check_with_token(repo_id, token_str) + + if is_hf_downloadable is True: + size = await _get_or_probe_size(url, token=token_str) + else: + # Skip the HEAD entirely — would 401 and we'd be stuck with + # cached None that survives a later login. + size = None + + return ProbeResult(file_size=size, is_hf_downloadable=is_hf_downloadable) + + +# --- gated/auth probes --------------------------------------------------- # + + +async def _resolve_is_gated(url: str, repo_id: str) -> Optional[bool]: + """Decide once whether ``repo_id`` is a gated repo.""" + cached = _is_gated_cache.get(url) + if cached is not None: + return cached + + async with _lock_for(url): + cached = _is_gated_cache.get(url) + if cached is not None: + return cached + try: + await asyncio.to_thread(_auth_check_sync, repo_id, None) + _is_gated_cache[url] = False + return False + except GatedRepoError: + _is_gated_cache[url] = True + return True + except RepositoryNotFoundError: + # Repo doesn't exist publicly. Treat as gated — we can't + # serve it without auth, and an authenticated check might + # still succeed if it's a private repo the user can see. + _is_gated_cache[url] = True + return True + except (HfHubHTTPError, Exception) as e: + logging.debug( + "[hf_auth] is_gated probe failed for %s (will retry): %s", + repo_id, e, + ) + return None # don't cache; retry next call + + +async def _auth_check_with_token( + repo_id: str, token: Optional[str] +) -> Optional[bool]: + """Auth-check with the supplied token. True/False/None per outcome.""" + try: + await asyncio.to_thread(_auth_check_sync, repo_id, token) + return True + except GatedRepoError: + return False + except RepositoryNotFoundError: + return False + except HfHubHTTPError as e: + # 401/403 covers org-SSO-required, revoked tokens, and similar — + # all of which mean "can't fetch right now" from the user's POV. + status = getattr(getattr(e, "response", None), "status_code", None) + if status in (401, 403): + return False + logging.debug( + "[hf_auth] auth_check transient failure for %s: %s", repo_id, e, + ) + return None + except Exception as e: + logging.warning("[hf_auth] unexpected auth_check error for %s: %s", repo_id, e) + return None + + +def _auth_check_sync(repo_id: str, token: Optional[str]) -> None: + """Thin sync wrapper around ``HfApi.auth_check`` for ``asyncio.to_thread``.""" + HfApi().auth_check(repo_id, token=token) + + +# --- size probe ---------------------------------------------------------- # + + +async def _get_or_probe_size(url: str, token: Optional[str]) -> Optional[int]: + """Return the cached size or HEAD the URL once and cache the result.""" + if url in _file_size_cache: + return _file_size_cache[url] + + async with _lock_for(url): + if url in _file_size_cache: + return _file_size_cache[url] + size = await _probe_size_once(url, token=token) + _file_size_cache[url] = size + return size + + +async def _probe_size_once(url: str, token: Optional[str]) -> Optional[int]: + """HEAD the URL and return the file size in bytes, or None on failure. + + HuggingFace serves LFS-tracked files via a 302 to a signed CDN URL. + The real file size lives in the non-standard ``X-Linked-Size`` header + on that 302 response (``Content-Length`` is the redirect-body length). + Disabling redirect-follow lets us read either header on the same + response: + + - LFS files: 302 + ``X-Linked-Size`` + - Small/non-LFS files: 200 + ``Content-Length`` + """ + headers = {"Authorization": f"Bearer {token}"} if token else {} + try: + session = await get_session() + async with session.head( + url, allow_redirects=False, timeout=_HEAD_TIMEOUT, headers=headers, + ) as resp: + linked = parse_content_length(resp.headers.get("X-Linked-Size")) + if linked is not None: + return linked + if resp.status == 200: + return parse_content_length(resp.headers.get("Content-Length")) + return None + except (aiohttp.ClientError, TimeoutError, OSError): + return None + + +# Backward-compat shim so consumers that still import the old name keep +# building during the refactor; can be removed once routes are updated. +MetadataProbeResult = ProbeResult diff --git a/app/model_downloader/hf_auth/auth_store.py b/app/model_downloader/hf_auth/auth_store.py new file mode 100644 index 000000000..2dcefdfe7 --- /dev/null +++ b/app/model_downloader/hf_auth/auth_store.py @@ -0,0 +1,106 @@ +"""In-memory token cache with lazy disk persistence + refresh. + +Public surface is the ``HF_AUTH_STORE`` singleton. Callers ask +``get_valid_token()``; the store transparently refreshes from disk +on first use, refreshes via the OAuth refresh_token if the cached +access_token is expired, and returns ``None`` if neither path works. + +The refresh path imports ``oauth.refresh_access_token`` lazily to +avoid an import cycle (oauth needs the store to save tokens it +acquires). +""" + +from __future__ import annotations + +import asyncio +import logging +import threading +from typing import Optional + +from app.model_downloader.hf_auth.token_store import ( + Token, + delete_token, + load_token, + save_token, +) + + +class HfAuthStore: + def __init__(self) -> None: + self._lock = threading.Lock() + self._token: Optional[Token] = None + self._loaded_from_disk = False + + def _ensure_loaded(self) -> None: + """Read the disk token into memory on first access.""" + if self._loaded_from_disk: + return + with self._lock: + if self._loaded_from_disk: + return + self._token = load_token() + self._loaded_from_disk = True + + def has_token(self) -> bool: + """Cheap check: is there any token in memory? + + Does not attempt refresh; an expired-but-refreshable token still + counts as "logged in" from the user's perspective. + """ + self._ensure_loaded() + return self._token is not None + + def set_token(self, token: Token) -> None: + """Replace the in-memory token and persist to disk.""" + with self._lock: + self._token = token + self._loaded_from_disk = True + save_token(token) + + def clear(self) -> None: + """Forget the token in memory and on disk (logout).""" + with self._lock: + self._token = None + self._loaded_from_disk = True + delete_token() + + def get_token_sync(self) -> Optional[Token]: + """Return the cached token without refreshing. + + Sync callers (e.g. constructing an Authorization header in a + non-async path) use this. They accept an expired token over + ``None``; HF will simply return 401 and the caller can decide + what to do. + """ + self._ensure_loaded() + return self._token + + async def get_valid_token(self) -> Optional[Token]: + """Return a fresh token, refreshing via OAuth if necessary. + + Returns ``None`` if there's no cached token at all, or if the + cached token is expired and refresh failed. Callers should + treat that as "user is not logged in". + """ + self._ensure_loaded() + tok = self._token + if tok is None: + return None + if tok.is_valid(): + return tok + if not tok.refresh_token: + return None + + # Lazy import to avoid the oauth ↔ store import cycle. + from app.model_downloader.hf_auth.oauth import refresh_access_token + + try: + refreshed = await refresh_access_token(tok.refresh_token) + except Exception as e: + logging.warning("[hf_auth] token refresh failed: %s", e) + return None + self.set_token(refreshed) + return refreshed + + +HF_AUTH_STORE = HfAuthStore() diff --git a/app/model_downloader/hf_auth/eligibility.py b/app/model_downloader/hf_auth/eligibility.py new file mode 100644 index 000000000..ad788e4cd --- /dev/null +++ b/app/model_downloader/hf_auth/eligibility.py @@ -0,0 +1,55 @@ +"""Whether this deployment is allowed to do interactive HF OAuth. + +We only let the server hold a HuggingFace token under a strict trust +assumption: this is a *single tenant local* install. Concretely: + + - The server is bound to a loopback address. SSH tunneling / + reverse-proxies can defeat this, but it's the strongest signal + we have without an authentication system. + - ``--multi-user`` is off. A shared token used implicitly by multiple + declared users would be a footgun — one user's gated downloads + would silently authenticate as another. + +Anything else and the frontend hides the HF login UI entirely; gated +models continue to show the "acquire it manually" message. +""" + +from __future__ import annotations + +import ipaddress +import socket + +from comfy.cli_args import args + + +def _is_loopback(host: str | None) -> bool: + """Duplicates ``server.is_loopback`` (small, no shared module yet). + + Resolves a host or IP literal to whether it lives on the loopback + interface (127.0.0.0/8 for IPv4, ::1 for IPv6). Returns False for + ``0.0.0.0`` / ``::`` because those are bind-all wildcards, not + loopback. + """ + if host is None: + return False + try: + return ipaddress.ip_address(host).is_loopback + except ValueError: + pass + + loopback = False + for family in (socket.AF_INET, socket.AF_INET6): + try: + r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM) + for _family, _, _, _, sockaddr in r: + if not ipaddress.ip_address(sockaddr[0]).is_loopback: + return loopback + loopback = True + except socket.gaierror: + pass + return loopback + + +def is_hf_auth_eligible() -> bool: + """True iff this deployment may surface the HF OAuth flow.""" + return _is_loopback(args.listen) and not args.multi_user diff --git a/app/model_downloader/hf_auth/oauth.py b/app/model_downloader/hf_auth/oauth.py new file mode 100644 index 000000000..7cff05328 --- /dev/null +++ b/app/model_downloader/hf_auth/oauth.py @@ -0,0 +1,277 @@ +"""OAuth 2.0 PKCE flow against HuggingFace's authorization server. + +Wired so that ``POST /api/hf-auth-login-start`` can: + 1. Generate state + PKCE verifier/challenge in this process. + 2. Spin up a short-lived loopback HTTP server at port 41954 to + receive the redirect callback from HF. + 3. Return the ``authorize_url`` for the frontend to open in a new tab. + +After the user grants consent on huggingface.co, HF redirects to the +local callback URL with ``code`` and ``state``. The callback server +validates ``state`` (CSRF), exchanges the code for tokens via PKCE, +hands the resulting Token to ``HF_AUTH_STORE.set_token``, and shuts +itself down. + +Before this can be exercised end-to-end a maintainer must register a +HuggingFace OAuth app and substitute the ``HF_CLIENT_ID`` placeholder +below. See the comment above the constant for the exact steps. +""" + +from __future__ import annotations + +import asyncio +import base64 +import hashlib +import logging +import secrets +import threading +import time + +import aiohttp +from aiohttp import web + +from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE +from app.model_downloader.hf_auth.token_store import Token +from app.model_downloader.http_client import ssl_context + + +# --- HF OAuth app registration -------------------------------------------- # +# NOTE: The OAuth client_id below is a placeholder. Before this feature can be +# exercised end-to-end, a maintainer must register a HuggingFace OAuth app +# under a Comfy-Org-controlled HF account and substitute its client_id here. +# Detailed walkthrough is in docs/server-side-model-downloads-handover.html +# ("HuggingFace OAuth app setup" section). Short version: +# 1. huggingface.co → Settings → Connected Apps → "Create app" +# 2. Default Scopes: check ``openid`` + ``profile`` (User Info) and +# ``gated-repos`` (Repository Access). Leave everything else off. +# 3. Redirect URLs: exactly ``http://127.0.0.1:41954/api/auth/huggingface/callback`` +# — must match ``REDIRECT_URI`` below; change both in lockstep if you +# change ``CALLBACK_PORT``. +# 4. Save → copy the resulting Client ID into ``HF_CLIENT_ID`` below. +# The client_id is not a secret (it travels through the user's browser in +# plaintext); HF's "Public app" type means there's no client secret to +# manage — PKCE replaces it. +HF_CLIENT_ID = "REPLACE_ME_WITH_COMFY_ORG_HF_OAUTH_CLIENT_ID" + +CALLBACK_HOST = "127.0.0.1" +CALLBACK_PORT = 41954 +CALLBACK_PATH = "/api/auth/huggingface/callback" +REDIRECT_URI = f"http://{CALLBACK_HOST}:{CALLBACK_PORT}{CALLBACK_PATH}" + +AUTHORIZE_URL = "https://huggingface.co/oauth/authorize" +TOKEN_URL = "https://huggingface.co/oauth/token" +# Minimal scope set for the feature: +# - openid : required by HF when the app uses OIDC at all +# - profile : lets ``HfApi.whoami(token=...)`` return a username for the +# settings UI; cosmetic but expected +# - gated-repos : grants the token enough to call ``auth_check`` and +# download files from public gated repos the user has +# accepted the license for. The wider ``read-repos`` scope +# would also work (it includes ``gated-repos``) but it +# additionally grants private-repo read access, which we +# don't need and which makes the consent screen scarier +# for the user. +SCOPE = "openid profile gated-repos" + +# Maximum time the callback server stays up waiting for the user to +# complete consent on huggingface.co. Past this, the port closes and +# the user has to click "Log in" again. +CALLBACK_TIMEOUT_SECS = 300 + + +# Process-wide lock so two simultaneous /api/hf-auth-login-start +# requests don't fight over port CALLBACK_PORT. +_OAUTH_LOCK = threading.Lock() + + +class OAuthInProgressError(Exception): + """Another OAuth attempt is already running.""" + + +class OAuthCallbackError(Exception): + """The OAuth callback returned an error (HF denied, port stolen, etc.).""" + + +# --- PKCE primitives ------------------------------------------------------ # + + +def _make_pkce() -> tuple[str, str, str]: + """Return ``(verifier, challenge, state)``. + + Verifier never leaves this process. Challenge and state travel + through the user's browser. State is checked on the callback to + prevent a malicious cross-origin redirect from injecting a token. + """ + verifier = secrets.token_urlsafe(64) + challenge = ( + base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("ascii")).digest()) + .rstrip(b"=") + .decode("ascii") + ) + state = secrets.token_urlsafe(32) + return verifier, challenge, state + + +def _build_authorize_url(challenge: str, state: str) -> str: + from urllib.parse import urlencode + + params = { + "client_id": HF_CLIENT_ID, + "redirect_uri": REDIRECT_URI, + "response_type": "code", + "scope": SCOPE, + "state": state, + "code_challenge": challenge, + "code_challenge_method": "S256", + } + return f"{AUTHORIZE_URL}?{urlencode(params)}" + + +# --- Token exchange ------------------------------------------------------- # + + +async def _exchange_code(code: str, verifier: str) -> Token: + """Trade the authorization code for an access+refresh token pair.""" + data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": REDIRECT_URI, + "client_id": HF_CLIENT_ID, + "code_verifier": verifier, + } + timeout = aiohttp.ClientTimeout(total=30) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(TOKEN_URL, data=data, ssl=ssl_context()) as resp: + resp.raise_for_status() + body = await resp.json() + return Token( + access_token=body["access_token"], + refresh_token=body.get("refresh_token"), + expires_at=time.time() + float(body.get("expires_in", 3600)), + scope=body.get("scope", SCOPE), + ) + + +async def refresh_access_token(refresh_token: str) -> Token: + """Trade a refresh_token for a new access (+ possibly refresh) token.""" + data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": HF_CLIENT_ID, + } + timeout = aiohttp.ClientTimeout(total=30) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(TOKEN_URL, data=data, ssl=ssl_context()) as resp: + resp.raise_for_status() + body = await resp.json() + return Token( + access_token=body["access_token"], + # If HF doesn't rotate refresh tokens, keep using the existing one. + refresh_token=body.get("refresh_token", refresh_token), + expires_at=time.time() + float(body.get("expires_in", 3600)), + scope=body.get("scope", SCOPE), + ) + + +# --- Callback server ------------------------------------------------------ # + + +async def start_login_flow() -> str: + """Begin one OAuth attempt: spawn the callback server, return the URL. + + Returns the URL the frontend should open in a new tab. Raises + ``OAuthInProgressError`` if another attempt is already running. + The callback server runs in the background until the user + completes consent or until ``CALLBACK_TIMEOUT_SECS`` elapses; + either way the lock + port are released afterward. + """ + if not _OAUTH_LOCK.acquire(blocking=False): + raise OAuthInProgressError() + + verifier, challenge, state = _make_pkce() + authorize_url = _build_authorize_url(challenge, state) + + # Fire the callback server on the running loop and return. + asyncio.create_task(_run_callback_server(verifier, state)) + return authorize_url + + +async def _run_callback_server(verifier: str, expected_state: str) -> None: + """Listen for HF's redirect once, capture the token, then shut down.""" + received: asyncio.Future[Token] = asyncio.get_event_loop().create_future() + + async def handler(request: web.Request) -> web.Response: + try: + if request.query.get("state") != expected_state: + return web.Response(status=400, text="state mismatch") + err = request.query.get("error") + if err: + received.set_exception(OAuthCallbackError(f"HF returned: {err}")) + return web.Response(status=400, text=f"OAuth error: {err}") + code = request.query.get("code") + if not code: + return web.Response(status=400, text="missing code") + tok = await _exchange_code(code, verifier) + if not received.done(): + received.set_result(tok) + return web.Response( + content_type="text/html", + text=( + "" + "

HuggingFace login successful

" + "

You can close this tab and return to ComfyUI.

" + "" + ), + ) + except Exception as exc: + if not received.done(): + received.set_exception(exc) + return web.Response(status=500, text=str(exc)) + + app = web.Application() + app.router.add_get(CALLBACK_PATH, handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, CALLBACK_HOST, CALLBACK_PORT, reuse_address=True) + try: + await site.start() + except OSError as e: + # Port already in use (or some other socket-bind failure). Release + # the lock so a future attempt has a chance to succeed. + logging.warning("[hf_auth] could not bind callback port: %s", e) + _OAUTH_LOCK.release() + return + + try: + token = await asyncio.wait_for(received, timeout=CALLBACK_TIMEOUT_SECS) + except asyncio.TimeoutError: + logging.info("[hf_auth] OAuth login timed out after %ds", CALLBACK_TIMEOUT_SECS) + return + except OAuthCallbackError as e: + logging.warning("[hf_auth] OAuth callback error: %s", e) + return + except Exception as e: + logging.warning("[hf_auth] unexpected OAuth failure: %s", e) + return + else: + HF_AUTH_STORE.set_token(token) + logging.info("[hf_auth] OAuth login complete") + finally: + await runner.cleanup() + if _OAUTH_LOCK.locked(): + _OAUTH_LOCK.release() + + +def is_login_in_progress() -> bool: + """True iff a callback server is currently bound + waiting.""" + return _OAUTH_LOCK.locked() + + +# Re-export for callers that only want the URL builder (e.g. tests). +__all__ = [ + "start_login_flow", + "refresh_access_token", + "is_login_in_progress", + "OAuthInProgressError", + "CALLBACK_TIMEOUT_SECS", +] diff --git a/app/model_downloader/hf_auth/token_store.py b/app/model_downloader/hf_auth/token_store.py new file mode 100644 index 000000000..28e0288f5 --- /dev/null +++ b/app/model_downloader/hf_auth/token_store.py @@ -0,0 +1,89 @@ +"""On-disk persistence for the HuggingFace OAuth token. + +The token shape mirrors what HF returns on the token exchange: an +``access_token``, an optional ``refresh_token``, the absolute epoch at +which the access token expires, and the granted scope. We persist +this so logging in once survives ComfyUI restarts; the file is mode +``0600`` so only the OS user can read it. +""" + +from __future__ import annotations + +import json +import logging +import os +import stat +import time +from dataclasses import asdict, dataclass +from typing import Optional + +import folder_paths + + +# Treat a token as expired this many seconds before its server-reported +# ``expires_at`` so we don't try to use a token mid-request only for it +# to flip stale between auth_check and the actual GET. +EXPIRY_BUFFER_SECS = 60 + +TOKEN_FILENAME = "hf_auth_token.json" + + +@dataclass +class Token: + """One OAuth token + the metadata we need to use it.""" + access_token: str + refresh_token: Optional[str] + expires_at: float # absolute epoch seconds + scope: str = "" + + def is_valid(self) -> bool: + """True iff we can use this token right now.""" + return ( + bool(self.access_token) + and (self.expires_at - time.time() > EXPIRY_BUFFER_SECS) + ) + + +def _token_path() -> str: + base = folder_paths.get_user_directory() + return os.path.join(base, TOKEN_FILENAME) + + +def load_token() -> Optional[Token]: + """Read the persisted token, returning ``None`` if absent or corrupt.""" + path = _token_path() + if not os.path.exists(path): + return None + try: + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + return Token(**data) + except (OSError, json.JSONDecodeError, TypeError) as e: + logging.warning("[hf_auth] could not load token at %s: %s", path, e) + return None + + +def save_token(token: Token) -> None: + """Atomically write the token with 0600 permissions.""" + path = _token_path() + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp = path + ".tmp" + with open(tmp, "w", encoding="utf-8") as f: + json.dump(asdict(token), f) + os.replace(tmp, path) + try: + os.chmod(path, stat.S_IRUSR | stat.S_IWUSR) + except OSError as e: + # On Windows / weird filesystems chmod may be a no-op; not fatal. + logging.debug("[hf_auth] chmod 0600 on %s failed: %s", path, e) + + +def delete_token() -> None: + """Remove the persisted token; no-op if it doesn't exist.""" + path = _token_path() + try: + os.remove(path) + except FileNotFoundError: + pass + except OSError as e: + logging.warning("[hf_auth] could not remove token at %s: %s", path, e) diff --git a/app/model_downloader/hf_url.py b/app/model_downloader/hf_url.py new file mode 100644 index 000000000..c305d5ac0 --- /dev/null +++ b/app/model_downloader/hf_url.py @@ -0,0 +1,41 @@ +"""Parsers for the ``huggingface.co`` URL shape we accept in workflows. + +The download API accepts URLs of the form +``https://huggingface.co///resolve//``. +We need to recover ``/`` (the *repo_id*) from such URLs for +``huggingface_hub`` API calls (notably ``HfApi.auth_check``). +""" + +from __future__ import annotations + +from typing import Optional +from urllib.parse import urlparse + +_HF_HOST = "huggingface.co" + + +def is_hf_url(url: str) -> bool: + """Cheap host check — does this URL point at huggingface.co?""" + try: + return urlparse(url).hostname == _HF_HOST + except ValueError: + return False + + +def repo_id_from_url(url: str) -> Optional[str]: + """Extract ``/`` from an HF model file URL. + + Returns ``None`` if the URL isn't on huggingface.co or doesn't look + like a model-file URL. The expected shape is + ``///resolve//`` — anything else + (datasets, spaces, /tree/, /blob/, …) we treat as out of scope here. + """ + if not is_hf_url(url): + return None + parts = urlparse(url).path.lstrip("/").split("/") + if len(parts) < 4 or parts[2] != "resolve": + return None + org, repo = parts[0], parts[1] + if not org or not repo: + return None + return f"{org}/{repo}" diff --git a/app/model_downloader/http_client.py b/app/model_downloader/http_client.py new file mode 100644 index 000000000..4c2b81dc6 --- /dev/null +++ b/app/model_downloader/http_client.py @@ -0,0 +1,63 @@ +"""Lazy module-level aiohttp ClientSession. + +A single shared session means TLS handshakes are reused across HEAD probes +and the subsequent GETs to the same host (HuggingFace is the dominant +case), which is a noticeable speedup on cold connections. + +We deliberately don't close the session at process exit — aiohttp's +warning about unclosed sessions is benign at shutdown, and adding atexit +plumbing buys nothing because the OS reclaims the sockets anyway. The +session lifetime is the lifetime of the Python process. +""" + +from __future__ import annotations + +import asyncio +import ssl +from typing import Optional + +import aiohttp +import certifi + + +# Larger per-host pool than aiohttp's default (=100 total / =0 per host) +# so concurrent gated probes + a download to the same host don't queue. +_CONNECTOR_LIMIT_PER_HOST = 8 + +_session: Optional[aiohttp.ClientSession] = None +_lock = asyncio.Lock() + + +def ssl_context() -> ssl.SSLContext: + """TLS context pinned to certifi's CA bundle. + aiohttp's default context uses the OS trust store, which isn't wired up + on some Python installs (python.org macOS, slim containers) — there TLS + to huggingface.co fails with CERTIFICATE_VERIFY_FAILED. + """ + return ssl.create_default_context(cafile=certifi.where()) + + +async def get_session() -> aiohttp.ClientSession: + """Return the shared session, creating it on first call.""" + global _session + if _session is not None and not _session.closed: + return _session + async with _lock: + if _session is None or _session.closed: + connector = aiohttp.TCPConnector( + limit_per_host=_CONNECTOR_LIMIT_PER_HOST, + ssl=ssl_context(), + ) + _session = aiohttp.ClientSession(connector=connector) + return _session + + +def parse_content_length(value: Optional[str]) -> Optional[int]: + """Parse a byte-count header value, or None if absent/malformed/negative.""" + if not value: + return None + try: + n = int(value) + except ValueError: + return None + return n if n >= 0 else None diff --git a/app/model_downloader/paths.py b/app/model_downloader/paths.py new file mode 100644 index 000000000..142ada8db --- /dev/null +++ b/app/model_downloader/paths.py @@ -0,0 +1,93 @@ +"""Path resolution for model downloads. + +Model identifiers used across the download API are *relative destination +paths* of the form ``/`` (e.g. ``loras/my_lora.safetensors``). +This module turns one of those identifiers into an absolute on-disk path +under one of ComfyUI's registered model folders, while rejecting unknown +folders, path traversal, and other ill-formed inputs. +""" + +import os +import re +from typing import Optional, Tuple + +import folder_paths + + +# Constrain components so a model_id can never escape its target directory. +# - directory: a single path segment of safe chars +# - filename: a single path segment of safe chars, must end with a model extension +_SEGMENT_RE = re.compile(r"^[A-Za-z0-9._-]+$") + + +class InvalidModelId(ValueError): + """Raised when a model_id is syntactically invalid or refers to an + unknown model folder.""" + + +def parse_model_id(model_id: str) -> Tuple[str, str]: + """Split ``/`` and validate both components. + + Returns ``(directory, filename)``. Raises ``InvalidModelId`` on + malformed input. Does NOT touch the filesystem. + """ + if not isinstance(model_id, str) or "/" not in model_id: + raise InvalidModelId(f"model_id must be '/', got {model_id!r}") + directory, _, filename = model_id.partition("/") + if "/" in filename or not directory or not filename: + raise InvalidModelId(f"model_id must be exactly one '/' separator, got {model_id!r}") + if not _SEGMENT_RE.match(directory): + raise InvalidModelId(f"invalid directory segment {directory!r}") + if not _SEGMENT_RE.match(filename): + raise InvalidModelId(f"invalid filename segment {filename!r}") + if directory not in folder_paths.folder_names_and_paths: + raise InvalidModelId(f"unknown model folder {directory!r}") + return directory, filename + + +def resolve_existing(model_id: str) -> Optional[str]: + """Return the absolute path of an installed model, or None if missing. + + Honours ``extra_model_paths.yaml`` transparently via + ``folder_paths.get_full_path``. + """ + directory, filename = parse_model_id(model_id) + return folder_paths.get_full_path(directory, filename) + + +def resolve_destination(model_id: str) -> Tuple[str, str]: + """Return ``(final_path, tmp_path)`` for a download. + + Downloads land at the first registered path for the model's directory + (the "primary" location). The ``.tmp`` sibling is used as the write + target and atomically renamed on success. + """ + directory, filename = parse_model_id(model_id) + roots = folder_paths.get_folder_paths(directory) + if not roots: + raise InvalidModelId(f"no on-disk path registered for folder {directory!r}") + root = roots[0] + final_path = os.path.join(root, filename) + tmp_path = final_path + ".tmp" + return final_path, tmp_path + + +def iter_all_tmp_paths(): + """Yield every ``*.tmp`` file under every registered model folder. + + Used at startup to sweep orphans left by crashed/restarted downloads. + """ + seen_roots: set[str] = set() + for directory in folder_paths.folder_names_and_paths.keys(): + for root in folder_paths.get_folder_paths(directory): + if root in seen_roots or not os.path.isdir(root): + continue + seen_roots.add(root) + try: + for entry in os.scandir(root): + if entry.is_file() and entry.name.endswith(".tmp"): + yield entry.path + except OSError: + # Folder might be unreadable / missing on certain mounts — + # not fatal, just skip it. + continue