diff --git a/app/model_downloader/allowlist.py b/app/model_downloader/allowlist.py new file mode 100644 index 000000000..ea5accc76 --- /dev/null +++ b/app/model_downloader/allowlist.py @@ -0,0 +1,46 @@ +"""URL allowlist for server-side model fetches. + +Mirrors the frontend's ``isModelDownloadable`` allowlist so the two flows +agree on which URLs are eligible for download. Server-side allowlisting is +the primary SSRF defense for this subsystem — workflow JSON is untrusted +input (anyone can hand-craft one), so we never let the server fetch URLs +outside this list. +""" + +from urllib.parse import urlparse + +# Frontend parity: ``missingModelDownload-*.js`` exports the same triple as +# ``i = [...]`` (Civitai / HuggingFace / localhost). +_ALLOWED_URL_PREFIXES = ( + "https://huggingface.co/", + "https://civitai.com/", + "http://localhost:", + "http://127.0.0.1:", +) + +# Frontend parity: same set as ``a = [...]`` in the bundle. +_ALLOWED_MODEL_EXTENSIONS = ( + ".safetensors", + ".sft", + ".ckpt", + ".pth", + ".pt", +) + + +def is_url_allowed(url: str) -> bool: + """Check whether ``url`` is permitted as a server-side download source. + + Returns True only when both: + - the URL starts with one of the allowed prefixes, AND + - the URL's final path segment ends with a known model extension. + + Both checks are required to keep arbitrary HTML / API endpoints on + allowlisted hosts (e.g. ``https://huggingface.co/api/...``) off the table. + """ + if not isinstance(url, str) or not url: + return False + if not any(url.startswith(p) for p in _ALLOWED_URL_PREFIXES): + return False + path = urlparse(url).path + return any(path.endswith(ext) for ext in _ALLOWED_MODEL_EXTENSIONS) diff --git a/app/model_downloader/api/routes.py b/app/model_downloader/api/routes.py new file mode 100644 index 000000000..921065540 --- /dev/null +++ b/app/model_downloader/api/routes.py @@ -0,0 +1,332 @@ +"""Aiohttp routes for the server-side model download subsystem. + +Endpoint surface (all under ``/api/``, all kebab-case): + + - ``POST /api/models-availability-status`` — bulk status + metadata query. + - ``POST /api/download-models`` — start a batch of downloads. + - ``POST /api/cancel-model-download-session`` — cancel a single in-flight one. + - ``GET /api/hf-auth-token-status`` — current HF login state. + - ``POST /api/hf-auth-login-start`` — begin the HF OAuth flow. + - ``POST /api/hf-auth-logout`` — drop the stored HF token. + +The contract is intentionally narrow: only model_ids of the form +``/`` (validated via ``app.model_downloader.paths``) +are accepted, and only URLs on the same allowlist the frontend already +uses (HuggingFace, Civitai, localhost) can be fetched. Both are required +to keep the server out of the SSRF business for this feature. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from typing import Any, Optional + +from aiohttp import web +from pydantic import BaseModel, ValidationError + +from app.model_downloader.allowlist import is_url_allowed +from app.model_downloader.download_server import ( + DOWNLOAD_SERVER, + DownloadSession, +) +from app.model_downloader.downloader import schedule_batch +from app.model_downloader.gated_detection import probe_url +from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE +from app.model_downloader.hf_auth.eligibility import is_hf_auth_eligible +from app.model_downloader.hf_auth.oauth import ( + OAuthInProgressError, + start_login_flow, +) +from app.model_downloader.paths import ( + InvalidModelId, + parse_model_id, + resolve_existing, +) +from app.model_downloader.api import schemas_in, schemas_out + +ROUTES = web.RouteTableDef() + + +def register_routes(app: web.Application) -> None: + """Wire the model-downloader routes into the running aiohttp app. + + Called once from ``server.py`` during ``PromptServer`` startup. + """ + app.add_routes(ROUTES) + + +# ----- response helpers (same envelope as app/assets/api/routes.py) ----- + + +def _error(status: int, code: str, message: str, details: dict | None = None) -> web.Response: + return web.json_response( + {"error": {"code": code, "message": message, "details": details or {}}}, + status=status, + ) + + +def _validation_error(code: str, ve: ValidationError) -> web.Response: + return _error(400, code, "Validation failed.", {"errors": json.loads(ve.json())}) + + +def _ok(payload: BaseModel, status: int = 200) -> web.Response: + return web.json_response( + payload.model_dump(mode="json", exclude_none=False), + status=status, + ) + + +async def _parse_body(request: web.Request, model: type[BaseModel]) -> Any: + """Parse a JSON body into a pydantic model or raise a 400 response.""" + try: + raw = await request.json() + except json.JSONDecodeError: + return _error(400, "INVALID_JSON", "Request body must be valid JSON.") + try: + return model.model_validate(raw) + except ValidationError as ve: + return _validation_error("INVALID_BODY", ve) + + +# ----- 1. availability status (unified: state + metadata per id) ----- + + +@ROUTES.post("/api/models-availability-status") +async def models_availability_status(request: web.Request) -> web.Response: + """Return per-id ``{state, progress, file_size, is_hf_downloadable}``. + + State (``available`` / ``missing`` / ``downloading``) is cheap to + recompute per call. ``file_size`` and ``is_gated`` are cached + server-side per URL. ``is_hf_downloadable`` is recomputed every + call from the current token state — that's what makes login + license + acceptance show up in the UI within one poll cycle without any + frontend cache plumbing. + """ + parsed = await _parse_body(request, schemas_in.AvailabilityStatusRequest) + if isinstance(parsed, web.Response): + return parsed + + items = list(parsed.models.items()) + + # Run all probes concurrently; each is internally cached per URL. + probes = await asyncio.gather(*(probe_url(url) for _, url in items)) + + response_models: dict[str, schemas_out.ModelStatusEntry] = {} + for (model_id, _url), probe in zip(items, probes): + try: + parse_model_id(model_id) + except InvalidModelId: + # Ill-formed identifier: report as missing without 400-ing the + # whole batch — the workflow author probably typo'd. + response_models[model_id] = schemas_out.ModelStatusEntry( + state="missing", + file_size=probe.file_size, + is_hf_downloadable=probe.is_hf_downloadable, + ) + continue + + active = DOWNLOAD_SERVER.get(model_id) + if active is not None: + response_models[model_id] = schemas_out.ModelStatusEntry( + state="downloading", + progress=schemas_out.DownloadProgress( + bytes_downloaded=active.bytes_downloaded, + total_bytes=active.total_bytes, + progress=active.progress, + ), + file_size=probe.file_size, + is_hf_downloadable=probe.is_hf_downloadable, + ) + continue + + state: schemas_out.ModelState = ( + "available" if resolve_existing(model_id) is not None else "missing" + ) + response_models[model_id] = schemas_out.ModelStatusEntry( + state=state, + file_size=probe.file_size, + is_hf_downloadable=probe.is_hf_downloadable, + ) + + return _ok(schemas_out.AvailabilityStatusResponse( + models=response_models, + hf_auth=schemas_out.HfAuthStatus( + token_available=HF_AUTH_STORE.has_token(), + eligible=is_hf_auth_eligible(), + ), + )) + + +# ----- 2. start downloads ----- + + +@ROUTES.post("/api/download-models") +async def download_models(request: web.Request) -> web.Response: + parsed = await _parse_body(request, schemas_in.DownloadModelsRequest) + if isinstance(parsed, web.Response): + return parsed + + if not parsed.models: + return _error(400, "EMPTY_REQUEST", "No models supplied.") + + # ----- precondition pass: validate everything BEFORE registering anything ----- + # Atomic semantics: if any model fails any precondition (invalid id, + # not allow-listed URL, already on disk, already downloading, or gated), + # the entire request fails and no state is changed. + requested = list(parsed.models.items()) + + for model_id, url in requested: + try: + parse_model_id(model_id) + except InvalidModelId as e: + return _error(400, "INVALID_MODEL_ID", str(e), + {"model_id": model_id}) + + if not is_url_allowed(url): + return _error( + 400, "URL_NOT_ALLOWED", + "Server-side downloads only accept HuggingFace, Civitai, " + "or localhost URLs ending in a known model extension.", + {"model_id": model_id, "url": url}, + ) + + if resolve_existing(model_id) is not None: + return _error(409, "ALREADY_AVAILABLE", + f"Model already exists on disk: {model_id}", + {"model_id": model_id}) + + if DOWNLOAD_SERVER.is_downloading(model_id): + return _error(409, "ALREADY_DOWNLOADING", + f"A download for {model_id} is already in progress.", + {"model_id": model_id}) + + # Reachability check last — it's the only one that talks to the + # network. Concurrent probes. For HF URLs ``is_hf_downloadable`` + # reflects current token access; for non-HF URLs it's None, and we + # treat that as "no info, proceed". + probes = await asyncio.gather(*(probe_url(url) for _, url in requested)) + for (model_id, url), probe in zip(requested, probes): + if probe.is_hf_downloadable is False: + return _error( + 400, "MODEL_NOT_DOWNLOADABLE", + f"Model {model_id} is gated on HuggingFace and the current " + f"server token (if any) does not grant access.", + {"model_id": model_id, "url": url}, + ) + + # ----- registration pass: try_register is atomic per model_id ----- + # Defensive: another request might have raced past our pre-check + # between the loop above and here. try_register handles that. + sessions: list[DownloadSession] = [] + for model_id, url in requested: + session = DOWNLOAD_SERVER.try_register(model_id, url) + if session is None: + # Race: someone else got in. Roll back what we registered. + for s in sessions: + DOWNLOAD_SERVER.cancel(s.model_id) + return _error(409, "ALREADY_DOWNLOADING", + f"A download for {model_id} is already in progress (race).", + {"model_id": model_id}) + sessions.append(session) + + DOWNLOAD_SERVER.sweep_orphan_tmp_files() + schedule_batch(sessions) + logging.info( + "[model_downloader] scheduled %d downloads: %s", + len(sessions), [s.model_id for s in sessions], + ) + + return _ok(schemas_out.DownloadModelsResponse( + accepted=True, + scheduled=[s.model_id for s in sessions], + ), status=202) + + +# ----- 3. cancel a session ----- + + +@ROUTES.post("/api/cancel-model-download-session") +async def cancel_model_download_session(request: web.Request) -> web.Response: + parsed = await _parse_body(request, schemas_in.CancelDownloadSessionRequest) + if isinstance(parsed, web.Response): + return parsed + + cancelled = DOWNLOAD_SERVER.cancel(parsed.model_id) + if not cancelled: + return _error(404, "NOT_DOWNLOADING", + f"No active download for {parsed.model_id}.", + {"model_id": parsed.model_id}) + + return _ok(schemas_out.CancelDownloadSessionResponse(cancelled=True)) + + +# ----- 4. HuggingFace OAuth status / login start / logout ----- + + +@ROUTES.get("/api/hf-auth-token-status") +async def hf_auth_token_status(request: web.Request) -> web.Response: + """Return whether the server holds a usable HF token + its username. + + Used by the settings UI and (out-of-band) by the frontend on + login completion. ``token_available`` is true even if the cached + access_token is expired — as long as a refresh_token exists, the + user is "logged in" from their perspective. + """ + token_present = HF_AUTH_STORE.has_token() + username: Optional[str] = None + if token_present: + # Resolve the username via whoami. Done in a worker thread because + # huggingface_hub's whoami is synchronous + blocks on a network call. + tok = await HF_AUTH_STORE.get_valid_token() + if tok is not None: + try: + username = await asyncio.to_thread(_whoami_username, tok.access_token) + except Exception as e: + logging.debug("[hf_auth] whoami failed: %s", e) + return _ok(schemas_out.HfAuthTokenStatusResponse( + token_available=token_present, + username=username, + )) + + +def _whoami_username(token: str) -> Optional[str]: + """Sync helper: ask HF for the user name attached to a token.""" + from huggingface_hub import HfApi + info = HfApi().whoami(token=token) + if isinstance(info, dict): + return info.get("name") or info.get("fullname") + return None + + +@ROUTES.post("/api/hf-auth-login-start") +async def hf_auth_login_start(request: web.Request) -> web.Response: + """Begin one OAuth attempt: bind the callback port, return the URL. + + Rejected outright if this deployment isn't eligible (we don't + surface the option on multi-tenant / public-IP installs). + """ + if not is_hf_auth_eligible(): + return _error( + 403, "HF_AUTH_NOT_ELIGIBLE", + "This server is not eligible for interactive HuggingFace login. " + "It must be bound to a loopback address and not running in " + "--multi-user mode.", + ) + try: + url = await start_login_flow() + except OAuthInProgressError: + return _error( + 409, "HF_AUTH_IN_PROGRESS", + "Another HuggingFace login attempt is in progress. Try again " + "after it completes or times out.", + ) + return _ok(schemas_out.HfAuthLoginStartResponse(authorize_url=url)) + + +@ROUTES.post("/api/hf-auth-logout") +async def hf_auth_logout(request: web.Request) -> web.Response: + """Drop the in-memory + on-disk HF token.""" + HF_AUTH_STORE.clear() + return _ok(schemas_out.HfAuthLogoutResponse(logged_out=True)) diff --git a/app/model_downloader/api/schemas_in.py b/app/model_downloader/api/schemas_in.py new file mode 100644 index 000000000..d792520ea --- /dev/null +++ b/app/model_downloader/api/schemas_in.py @@ -0,0 +1,41 @@ +"""Request schemas for the model-downloader API. + +Each endpoint accepts a small JSON body. Pydantic enforces the shape at +the boundary; route handlers operate only on validated values past that. +""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class AvailabilityStatusRequest(BaseModel): + """``POST /api/models-availability-status``. + + Sent by the frontend on each poll. Each entry is ``{model_id: url}``; + the URL is the one declared in ``properties.models[i].url`` in the + workflow JSON and lets the server compute per-id metadata + (``file_size`` + ``is_hf_downloadable``) on the same request. + """ + models: dict[str, str] = Field(default_factory=dict) + + +class DownloadModelsRequest(BaseModel): + """``POST /api/download-models``. + + Same shape as the metadata request — the URL for each model_id. + Returns immediately after validation and scheduling. + """ + models: dict[str, str] = Field(default_factory=dict) + + +class CancelDownloadSessionRequest(BaseModel): + """``POST /api/cancel-model-download-session``.""" + model_id: str + + +__all__ = [ + "AvailabilityStatusRequest", + "DownloadModelsRequest", + "CancelDownloadSessionRequest", +] diff --git a/app/model_downloader/api/schemas_out.py b/app/model_downloader/api/schemas_out.py new file mode 100644 index 000000000..5571ecd28 --- /dev/null +++ b/app/model_downloader/api/schemas_out.py @@ -0,0 +1,81 @@ +"""Response schemas for the model-downloader API.""" + +from __future__ import annotations + +from typing import Literal, Optional + +from pydantic import BaseModel + + +ModelState = Literal["available", "missing", "downloading"] + + +class DownloadProgress(BaseModel): + """Embedded in a model entry when its state is ``downloading``.""" + bytes_downloaded: int + total_bytes: Optional[int] = None + progress: Optional[float] = None # fraction in [0,1]; null until total known + + +class ModelStatusEntry(BaseModel): + """Everything the UI needs to render one row, in one shot. + + ``state`` reflects what the server has on disk + in-flight; ``file_size`` + and ``is_hf_downloadable`` come from probes (intrinsic; cached). + The HF fields are populated for every poll (cached on the server), + so license-acceptance flips show up within one poll interval without + any frontend cache invalidation. + """ + state: ModelState + progress: Optional[DownloadProgress] = None + file_size: Optional[int] = None + # HF-only: True iff the server can fetch this URL with current auth + # state. False iff gated and lacking access. None for non-HF URLs. + is_hf_downloadable: Optional[bool] = None + + +class HfAuthStatus(BaseModel): + """Snapshot of HF login state, embedded in availability response.""" + token_available: bool + eligible: bool + + +class AvailabilityStatusResponse(BaseModel): + models: dict[str, ModelStatusEntry] + hf_auth: HfAuthStatus + + +class DownloadModelsResponse(BaseModel): + accepted: bool + scheduled: list[str] + + +class CancelDownloadSessionResponse(BaseModel): + cancelled: bool + + +class HfAuthTokenStatusResponse(BaseModel): + token_available: bool + username: Optional[str] = None + + +class HfAuthLoginStartResponse(BaseModel): + authorize_url: str + + +class HfAuthLogoutResponse(BaseModel): + logged_out: bool + + +__all__ = [ + "ModelState", + "DownloadProgress", + "ModelStatusEntry", + "HfAuthStatus", + "AvailabilityStatusResponse", + "DownloadModelsResponse", + "CancelDownloadSessionResponse", + "HfAuthTokenStatusResponse", + "HfAuthLoginStartResponse", + "HfAuthLogoutResponse", +] diff --git a/app/model_downloader/download_server.py b/app/model_downloader/download_server.py new file mode 100644 index 000000000..c323309d8 --- /dev/null +++ b/app/model_downloader/download_server.py @@ -0,0 +1,179 @@ +"""Process-wide registry of in-flight model downloads. + +A single instance, ``DOWNLOAD_SERVER``, tracks every currently-running +server-side model fetch. Designed to be safe with multiple concurrent +clients hitting the API: each model_id has at most one active session, +and the API rejects requests that conflict with in-flight downloads. + +Cancellation is cooperative — the download loop checks ``is_active`` on +its own session between chunks and raises ``DownloadCancelled`` when the +session has been removed from the registry. This avoids the complications +of ``Task.cancel()`` from outside the loop while still giving deterministic +rollback semantics (the worker is responsible for deleting its own +``.tmp`` on the cancel path). +""" + +from __future__ import annotations + +import logging +import os +import threading +from dataclasses import dataclass, field +from typing import Optional + +from app.model_downloader.paths import iter_all_tmp_paths + + +class DownloadCancelled(Exception): + """Raised by the streaming loop when its session has been removed + from the registry (cancellation request) and the worker should roll + back its ``.tmp`` file.""" + + +@dataclass +class DownloadSession: + """One in-flight download. + + ``progress`` is a fraction in ``[0.0, 1.0]``; ``None`` until the first + byte arrives and we know whether the response carries a + ``Content-Length``. ``total_bytes`` mirrors that header when present. + """ + model_id: str + url: str + progress: Optional[float] = None + bytes_downloaded: int = 0 + total_bytes: Optional[int] = None + # Sequence number used solely as identity for the cancellation check — + # so that "cancel + restart" doesn't get confused by stale workers. + epoch: int = field(default_factory=lambda: 0) + + +class DownloadServer: + """Singleton registry of active downloads. + + All mutation goes through this object so concurrent route handlers + see a consistent view. The ``_lock`` is a plain threading lock + because the registry is consulted from both the asyncio event-loop + thread (route handlers) and from any worker coroutines spawned to + perform downloads. + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._sessions: dict[str, DownloadSession] = {} + self._epoch_counter = 0 + self._orphan_sweep_done = False + + # ----- lifecycle ----- + + def sweep_orphan_tmp_files(self) -> None: + """Idempotently sweep ``*.tmp`` files left by crashed downloads. + + Deferred off the import path so module load doesn't block on + filesystem I/O against potentially-slow mounts. Each route handler + that might create a new ``.tmp`` runs this exactly once. + """ + with self._lock: + if self._orphan_sweep_done: + return + self._orphan_sweep_done = True + for path in iter_all_tmp_paths(): + try: + os.remove(path) + logging.info("[model_downloader] removed orphan tmp file: %s", path) + except OSError as e: + logging.warning("[model_downloader] could not remove %s: %s", path, e) + + # ----- queries ----- + + def is_downloading(self, model_id: str) -> bool: + with self._lock: + return model_id in self._sessions + + def get(self, model_id: str) -> Optional[DownloadSession]: + with self._lock: + return self._sessions.get(model_id) + + def snapshot(self) -> dict[str, DownloadSession]: + """Return a shallow copy of the current sessions map.""" + with self._lock: + return dict(self._sessions) + + # ----- mutations ----- + + def try_register(self, model_id: str, url: str) -> Optional[DownloadSession]: + """Atomically register a new session iff none exists for ``model_id``. + + Returns the new session on success, ``None`` if a session is already + in flight. Callers must check the return value — the caller is the + sole owner of the session it gets back. + """ + with self._lock: + if model_id in self._sessions: + return None + self._epoch_counter += 1 + session = DownloadSession( + model_id=model_id, + url=url, + epoch=self._epoch_counter, + ) + self._sessions[model_id] = session + return session + + def update_progress( + self, + session: DownloadSession, + bytes_downloaded: int, + total_bytes: Optional[int], + ) -> None: + """Update progress on a session. No-op if the session has been + removed (cancelled) — caller should check ``is_active`` separately.""" + with self._lock: + current = self._sessions.get(session.model_id) + if current is None or current.epoch != session.epoch: + return + current.bytes_downloaded = bytes_downloaded + current.total_bytes = total_bytes + if total_bytes and total_bytes > 0: + current.progress = min(1.0, bytes_downloaded / total_bytes) + + def is_active(self, session: DownloadSession) -> bool: + """True iff this exact session is still the registered one for + its model_id. False after cancellation, after completion, or if + another session has replaced it.""" + with self._lock: + current = self._sessions.get(session.model_id) + return current is not None and current.epoch == session.epoch + + def finish(self, session: DownloadSession) -> None: + """Remove a completed (or cancelled) session from the registry. + + Safe to call multiple times. Only removes if the epoch matches — + we never accidentally evict a *newer* session for the same model_id. + """ + with self._lock: + current = self._sessions.get(session.model_id) + if current is not None and current.epoch == session.epoch: + del self._sessions[session.model_id] + + def reset_for_tests(self) -> None: + """Clear all sessions and reset the epoch counter. Test-only.""" + with self._lock: + self._sessions.clear() + self._epoch_counter = 0 + + def cancel(self, model_id: str) -> bool: + """Remove the session registered for ``model_id``. + + Returns True if there was an active session to cancel. The worker + will discover the cancellation on its next ``is_active`` check + and roll back its ``.tmp`` file. + """ + with self._lock: + if model_id in self._sessions: + del self._sessions[model_id] + return True + return False + + +DOWNLOAD_SERVER = DownloadServer() diff --git a/app/model_downloader/downloader.py b/app/model_downloader/downloader.py new file mode 100644 index 000000000..091a4a398 --- /dev/null +++ b/app/model_downloader/downloader.py @@ -0,0 +1,205 @@ +"""Streaming download worker with progress reporting and cancellation. + +Each download writes to ``.tmp`` and atomically renames into +place on success. Between chunks the worker checks the registry for +cancellation (via ``DownloadServer.is_active``) and rolls back its +``.tmp`` on cancel or on any error. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +from typing import Optional + +import aiohttp + +from app.model_downloader.download_server import ( + DOWNLOAD_SERVER, + DownloadCancelled, + DownloadSession, +) +from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE +from app.model_downloader.hf_url import is_hf_url +from app.model_downloader.http_client import get_session, parse_content_length +from app.model_downloader.paths import resolve_destination + + +CHUNK_SIZE = 64 * 1024 # 64 KiB — same scale as other ComfyUI download paths. +REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=120) + + +async def stream_to_disk(session: DownloadSession) -> str: + """Run a single download to completion or cancellation. + + Returns the final on-disk path on success. Removes the ``.tmp`` and + raises on cancellation or failure. The session is finished + (removed from the registry) exactly once, here — callers do not + need to call ``DOWNLOAD_SERVER.finish`` themselves. + """ + final_path, tmp_path = resolve_destination(session.model_id) + os.makedirs(os.path.dirname(final_path), exist_ok=True) + + # Wipe any stale .tmp from a previous failed attempt before we start — + # otherwise a partial body could masquerade as our completed download + # when the rename finally happens. + _remove_if_exists(tmp_path) + + bytes_seen = 0 + try: + http = await get_session() + headers = _auth_headers_for(session.url) + logging.info( + "[model_downloader] starting GET %s (auth=%s)", + session.url, "yes" if "Authorization" in headers else "no", + ) + async with http.get( + session.url, + allow_redirects=True, + timeout=REQUEST_TIMEOUT, + headers=headers, + ) as resp: + if resp.status != 200: + # Capture a snippet of the response body so 4xx/5xx aren't + # opaque in the logs — HF returns JSON or HTML with a + # human-readable reason on failures. + body_snippet = await _read_short(resp) + logging.warning( + "[model_downloader] GET %s failed: status=%d final_url=%s body=%s", + session.url, resp.status, str(resp.url), body_snippet, + ) + raise DownloadError( + f"unexpected HTTP {resp.status} fetching {session.url}: {body_snippet}", + status=resp.status, + ) + + total = parse_content_length(resp.headers.get("Content-Length")) + DOWNLOAD_SERVER.update_progress(session, 0, total) + + with open(tmp_path, "wb") as f: + async for chunk in resp.content.iter_chunked(CHUNK_SIZE): + # Cancellation check between chunks. Cheap and means + # cancellation latency is bounded by one chunk plus + # one ``write()`` — typically well under a second + # even on slow disks. + if not DOWNLOAD_SERVER.is_active(session): + raise DownloadCancelled() + f.write(chunk) + bytes_seen += len(chunk) + DOWNLOAD_SERVER.update_progress(session, bytes_seen, total) + + # Final cancellation check before we promote the .tmp to the real + # filename — avoids the awkward case where cancel arrives during + # the very last chunk and we'd otherwise commit anyway. + if not DOWNLOAD_SERVER.is_active(session): + raise DownloadCancelled() + + # Atomic rename. os.replace is atomic within the same filesystem, + # which is guaranteed here because tmp lives alongside final_path. + os.replace(tmp_path, final_path) + logging.info( + "[model_downloader] downloaded %s (%d bytes) from %s", + session.model_id, bytes_seen, session.url, + ) + return final_path + + except DownloadCancelled: + logging.info("[model_downloader] cancelled: %s", session.model_id) + _remove_if_exists(tmp_path) + raise + except Exception as e: + logging.warning( + "[model_downloader] failed: %s from %s: %s: %s", + session.model_id, session.url, type(e).__name__, e, + exc_info=True, + ) + _remove_if_exists(tmp_path) + raise + finally: + # In all terminal states (success / cancel / error) drop the + # session from the registry. Idempotent — only removes if we're + # still the live epoch for this model_id. + DOWNLOAD_SERVER.finish(session) + + +class DownloadError(Exception): + """Network / protocol error during a download.""" + + def __init__(self, message: str, status: Optional[int] = None) -> None: + super().__init__(message) + self.status = status + + +async def _read_short(resp: aiohttp.ClientResponse, limit: int = 512) -> str: + """Read up to ``limit`` bytes of a response body for logging. + + Used to surface the JSON/HTML reason from an HF non-2xx response in + server logs instead of just the status code. Best-effort: any + error here is swallowed. + """ + try: + raw = await resp.content.read(limit) + return raw.decode("utf-8", errors="replace").strip() + except Exception: + return "" + + +def _auth_headers_for(url: str) -> dict[str, str]: + """Return any auth headers we should add to the GET for ``url``. + + For HuggingFace URLs we inject the user's OAuth access token as a + Bearer header — this is HF's documented way to access gated repos + (see ``huggingface_hub.hf_hub_download``'s wire format). For every + other host we send no extra headers; allowlisted public files + don't need them and we don't want to leak tokens to other hosts. + """ + if not is_hf_url(url): + return {} + tok = HF_AUTH_STORE.get_token_sync() + if tok is None or not tok.access_token: + return {} + return {"Authorization": f"Bearer {tok.access_token}"} + + +def _remove_if_exists(path: str) -> None: + try: + os.remove(path) + except FileNotFoundError: + pass + except OSError as e: + logging.warning("[model_downloader] could not remove %s: %s", path, e) + + +async def run_batch_sequential(sessions: list[DownloadSession]) -> None: + """Run a list of sessions one after the other. + + Each session is independent: a failure or cancellation of one does + not abort the rest. Cancellations are observable via the registry + *before* a given download starts, so a session that's been + pre-cancelled (cancel before the worker reached it) just gets skipped. + """ + for session in sessions: + # If the session got cancelled before its turn, skip without + # touching disk. This is what makes the per-request "sequential + # but cancellable" semantic work. + if not DOWNLOAD_SERVER.is_active(session): + DOWNLOAD_SERVER.finish(session) + continue + try: + await stream_to_disk(session) + except DownloadCancelled: + # Already logged + tmp removed inside stream_to_disk. + continue + except Exception: + # stream_to_disk already logged. Continue with the rest of the batch. + continue + + +def schedule_batch(sessions: list[DownloadSession]) -> asyncio.Task: + """Kick off ``run_batch_sequential`` on the running event loop. + + Returned task is fire-and-forget; the API handler returns immediately + after scheduling and clients observe progress via the polling endpoints. + """ + return asyncio.create_task(run_batch_sequential(sessions)) diff --git a/app/model_downloader/gated_detection.py b/app/model_downloader/gated_detection.py new file mode 100644 index 000000000..1dc2d3545 --- /dev/null +++ b/app/model_downloader/gated_detection.py @@ -0,0 +1,238 @@ +"""Per-URL probes for the unified availability endpoint. + +Three cached/derived facts per URL: + + - ``is_gated`` intrinsic to the model; cached forever once known. + Determined by ``auth_check(repo_id, token=None)``: + ``GatedRepoError`` → True, success → False. + + - ``is_hf_downloadable`` depends on the *current* token; recomputed every + call. For non-gated URLs this is trivially True + (no HF call needed). For gated URLs we run + ``auth_check`` with the stored token each call. + + - ``file_size`` intrinsic to the file. Cached forever once + determined (including ``None`` on transient + failure — we don't retry). We only attempt the + HEAD when we already know the URL is downloadable + to us; that way a failed-because-gated probe + never lands as a cached ``None``. + +Caches are per-process, in-memory; small, no eviction needed for the +workflow-scale (~tens of URLs). Concurrent calls for the same URL +deduplicate via per-URL ``asyncio.Lock``. +""" + +from __future__ import annotations + +import asyncio +import logging +from dataclasses import dataclass +from typing import Optional + +import aiohttp +from huggingface_hub import HfApi +from huggingface_hub.errors import ( + GatedRepoError, + HfHubHTTPError, + RepositoryNotFoundError, +) + +from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE +from app.model_downloader.hf_url import is_hf_url, repo_id_from_url +from app.model_downloader.http_client import get_session, parse_content_length + + +_HEAD_TIMEOUT = aiohttp.ClientTimeout(total=15) + + +@dataclass +class ProbeResult: + file_size: Optional[int] + is_hf_downloadable: Optional[bool] + + +# --- caches -------------------------------------------------------------- # + + +# url → bool. Whether this URL's HF repo gates access. Intrinsic to the +# model — never changes for a given URL. +_is_gated_cache: dict[str, bool] = {} + +# url → Optional[int]. The file's size in bytes, ``None`` if a probe +# was attempted and produced no answer. **Only populated when we knew +# the URL was downloadable to us at probe time** — so gated-without- +# access never lands a ``None`` here that we'd be stuck with after login. +_file_size_cache: dict[str, Optional[int]] = {} + +# Per-URL locks for single-flight probes — when multiple polls arrive +# in the same tick for the same URL, exactly one of them runs the HF +# call and the others wait on the result. +_locks: dict[str, asyncio.Lock] = {} + + +def _lock_for(url: str) -> asyncio.Lock: + lock = _locks.get(url) + if lock is None: + lock = asyncio.Lock() + _locks[url] = lock + return lock + + +def clear_caches_for_tests() -> None: + """Test-only: drop everything.""" + _is_gated_cache.clear() + _file_size_cache.clear() + _locks.clear() + + +# --- public entrypoint --------------------------------------------------- # + + +async def probe_url(url: str) -> ProbeResult: + """Return downloadability + size for one URL, using caches where safe.""" + if not is_hf_url(url): + # Non-HF: ``is_hf_downloadable`` is "not applicable" (None). + # Size we still cache so we don't HEAD on every poll. + size = await _get_or_probe_size(url, token=None) + return ProbeResult(file_size=size, is_hf_downloadable=None) + + repo_id = repo_id_from_url(url) + if repo_id is None: + return ProbeResult(file_size=None, is_hf_downloadable=None) + + # Determine intrinsic gating once. + gated = await _resolve_is_gated(url, repo_id) + if gated is None: + return ProbeResult(file_size=None, is_hf_downloadable=None) + + # Compute current-token downloadability per call. + tok = HF_AUTH_STORE.get_token_sync() + token_str: Optional[str] = tok.access_token if tok else None + if not gated: + is_hf_downloadable: Optional[bool] = True + else: + is_hf_downloadable = await _auth_check_with_token(repo_id, token_str) + + if is_hf_downloadable is True: + size = await _get_or_probe_size(url, token=token_str) + else: + # Skip the HEAD entirely — would 401 and we'd be stuck with + # cached None that survives a later login. + size = None + + return ProbeResult(file_size=size, is_hf_downloadable=is_hf_downloadable) + + +# --- gated/auth probes --------------------------------------------------- # + + +async def _resolve_is_gated(url: str, repo_id: str) -> Optional[bool]: + """Decide once whether ``repo_id`` is a gated repo.""" + cached = _is_gated_cache.get(url) + if cached is not None: + return cached + + async with _lock_for(url): + cached = _is_gated_cache.get(url) + if cached is not None: + return cached + try: + await asyncio.to_thread(_auth_check_sync, repo_id, None) + _is_gated_cache[url] = False + return False + except GatedRepoError: + _is_gated_cache[url] = True + return True + except RepositoryNotFoundError: + # Repo doesn't exist publicly. Treat as gated — we can't + # serve it without auth, and an authenticated check might + # still succeed if it's a private repo the user can see. + _is_gated_cache[url] = True + return True + except (HfHubHTTPError, Exception) as e: + logging.debug( + "[hf_auth] is_gated probe failed for %s (will retry): %s", + repo_id, e, + ) + return None # don't cache; retry next call + + +async def _auth_check_with_token( + repo_id: str, token: Optional[str] +) -> Optional[bool]: + """Auth-check with the supplied token. True/False/None per outcome.""" + try: + await asyncio.to_thread(_auth_check_sync, repo_id, token) + return True + except GatedRepoError: + return False + except RepositoryNotFoundError: + return False + except HfHubHTTPError as e: + # 401/403 covers org-SSO-required, revoked tokens, and similar — + # all of which mean "can't fetch right now" from the user's POV. + status = getattr(getattr(e, "response", None), "status_code", None) + if status in (401, 403): + return False + logging.debug( + "[hf_auth] auth_check transient failure for %s: %s", repo_id, e, + ) + return None + except Exception as e: + logging.warning("[hf_auth] unexpected auth_check error for %s: %s", repo_id, e) + return None + + +def _auth_check_sync(repo_id: str, token: Optional[str]) -> None: + """Thin sync wrapper around ``HfApi.auth_check`` for ``asyncio.to_thread``.""" + HfApi().auth_check(repo_id, token=token) + + +# --- size probe ---------------------------------------------------------- # + + +async def _get_or_probe_size(url: str, token: Optional[str]) -> Optional[int]: + """Return the cached size or HEAD the URL once and cache the result.""" + if url in _file_size_cache: + return _file_size_cache[url] + + async with _lock_for(url): + if url in _file_size_cache: + return _file_size_cache[url] + size = await _probe_size_once(url, token=token) + _file_size_cache[url] = size + return size + + +async def _probe_size_once(url: str, token: Optional[str]) -> Optional[int]: + """HEAD the URL and return the file size in bytes, or None on failure. + + HuggingFace serves LFS-tracked files via a 302 to a signed CDN URL. + The real file size lives in the non-standard ``X-Linked-Size`` header + on that 302 response (``Content-Length`` is the redirect-body length). + Disabling redirect-follow lets us read either header on the same + response: + + - LFS files: 302 + ``X-Linked-Size`` + - Small/non-LFS files: 200 + ``Content-Length`` + """ + headers = {"Authorization": f"Bearer {token}"} if token else {} + try: + session = await get_session() + async with session.head( + url, allow_redirects=False, timeout=_HEAD_TIMEOUT, headers=headers, + ) as resp: + linked = parse_content_length(resp.headers.get("X-Linked-Size")) + if linked is not None: + return linked + if resp.status == 200: + return parse_content_length(resp.headers.get("Content-Length")) + return None + except (aiohttp.ClientError, TimeoutError, OSError): + return None + + +# Backward-compat shim so consumers that still import the old name keep +# building during the refactor; can be removed once routes are updated. +MetadataProbeResult = ProbeResult diff --git a/app/model_downloader/hf_auth/auth_store.py b/app/model_downloader/hf_auth/auth_store.py new file mode 100644 index 000000000..2dcefdfe7 --- /dev/null +++ b/app/model_downloader/hf_auth/auth_store.py @@ -0,0 +1,106 @@ +"""In-memory token cache with lazy disk persistence + refresh. + +Public surface is the ``HF_AUTH_STORE`` singleton. Callers ask +``get_valid_token()``; the store transparently refreshes from disk +on first use, refreshes via the OAuth refresh_token if the cached +access_token is expired, and returns ``None`` if neither path works. + +The refresh path imports ``oauth.refresh_access_token`` lazily to +avoid an import cycle (oauth needs the store to save tokens it +acquires). +""" + +from __future__ import annotations + +import asyncio +import logging +import threading +from typing import Optional + +from app.model_downloader.hf_auth.token_store import ( + Token, + delete_token, + load_token, + save_token, +) + + +class HfAuthStore: + def __init__(self) -> None: + self._lock = threading.Lock() + self._token: Optional[Token] = None + self._loaded_from_disk = False + + def _ensure_loaded(self) -> None: + """Read the disk token into memory on first access.""" + if self._loaded_from_disk: + return + with self._lock: + if self._loaded_from_disk: + return + self._token = load_token() + self._loaded_from_disk = True + + def has_token(self) -> bool: + """Cheap check: is there any token in memory? + + Does not attempt refresh; an expired-but-refreshable token still + counts as "logged in" from the user's perspective. + """ + self._ensure_loaded() + return self._token is not None + + def set_token(self, token: Token) -> None: + """Replace the in-memory token and persist to disk.""" + with self._lock: + self._token = token + self._loaded_from_disk = True + save_token(token) + + def clear(self) -> None: + """Forget the token in memory and on disk (logout).""" + with self._lock: + self._token = None + self._loaded_from_disk = True + delete_token() + + def get_token_sync(self) -> Optional[Token]: + """Return the cached token without refreshing. + + Sync callers (e.g. constructing an Authorization header in a + non-async path) use this. They accept an expired token over + ``None``; HF will simply return 401 and the caller can decide + what to do. + """ + self._ensure_loaded() + return self._token + + async def get_valid_token(self) -> Optional[Token]: + """Return a fresh token, refreshing via OAuth if necessary. + + Returns ``None`` if there's no cached token at all, or if the + cached token is expired and refresh failed. Callers should + treat that as "user is not logged in". + """ + self._ensure_loaded() + tok = self._token + if tok is None: + return None + if tok.is_valid(): + return tok + if not tok.refresh_token: + return None + + # Lazy import to avoid the oauth ↔ store import cycle. + from app.model_downloader.hf_auth.oauth import refresh_access_token + + try: + refreshed = await refresh_access_token(tok.refresh_token) + except Exception as e: + logging.warning("[hf_auth] token refresh failed: %s", e) + return None + self.set_token(refreshed) + return refreshed + + +HF_AUTH_STORE = HfAuthStore() diff --git a/app/model_downloader/hf_auth/eligibility.py b/app/model_downloader/hf_auth/eligibility.py new file mode 100644 index 000000000..ad788e4cd --- /dev/null +++ b/app/model_downloader/hf_auth/eligibility.py @@ -0,0 +1,55 @@ +"""Whether this deployment is allowed to do interactive HF OAuth. + +We only let the server hold a HuggingFace token under a strict trust +assumption: this is a *single tenant local* install. Concretely: + + - The server is bound to a loopback address. SSH tunneling / + reverse-proxies can defeat this, but it's the strongest signal + we have without an authentication system. + - ``--multi-user`` is off. A shared token used implicitly by multiple + declared users would be a footgun — one user's gated downloads + would silently authenticate as another. + +Anything else and the frontend hides the HF login UI entirely; gated +models continue to show the "acquire it manually" message. +""" + +from __future__ import annotations + +import ipaddress +import socket + +from comfy.cli_args import args + + +def _is_loopback(host: str | None) -> bool: + """Duplicates ``server.is_loopback`` (small, no shared module yet). + + Resolves a host or IP literal to whether it lives on the loopback + interface (127.0.0.0/8 for IPv4, ::1 for IPv6). Returns False for + ``0.0.0.0`` / ``::`` because those are bind-all wildcards, not + loopback. + """ + if host is None: + return False + try: + return ipaddress.ip_address(host).is_loopback + except ValueError: + pass + + loopback = False + for family in (socket.AF_INET, socket.AF_INET6): + try: + r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM) + for _family, _, _, _, sockaddr in r: + if not ipaddress.ip_address(sockaddr[0]).is_loopback: + return loopback + loopback = True + except socket.gaierror: + pass + return loopback + + +def is_hf_auth_eligible() -> bool: + """True iff this deployment may surface the HF OAuth flow.""" + return _is_loopback(args.listen) and not args.multi_user diff --git a/app/model_downloader/hf_auth/oauth.py b/app/model_downloader/hf_auth/oauth.py new file mode 100644 index 000000000..7cff05328 --- /dev/null +++ b/app/model_downloader/hf_auth/oauth.py @@ -0,0 +1,277 @@ +"""OAuth 2.0 PKCE flow against HuggingFace's authorization server. + +Wired so that ``POST /api/hf-auth-login-start`` can: + 1. Generate state + PKCE verifier/challenge in this process. + 2. Spin up a short-lived loopback HTTP server at port 41954 to + receive the redirect callback from HF. + 3. Return the ``authorize_url`` for the frontend to open in a new tab. + +After the user grants consent on huggingface.co, HF redirects to the +local callback URL with ``code`` and ``state``. The callback server +validates ``state`` (CSRF), exchanges the code for tokens via PKCE, +hands the resulting Token to ``HF_AUTH_STORE.set_token``, and shuts +itself down. + +Before this can be exercised end-to-end a maintainer must register a +HuggingFace OAuth app and substitute the ``HF_CLIENT_ID`` placeholder +below. See the comment above the constant for the exact steps. +""" + +from __future__ import annotations + +import asyncio +import base64 +import hashlib +import logging +import secrets +import threading +import time + +import aiohttp +from aiohttp import web + +from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE +from app.model_downloader.hf_auth.token_store import Token +from app.model_downloader.http_client import ssl_context + + +# --- HF OAuth app registration -------------------------------------------- # +# NOTE: The OAuth client_id below is a placeholder. Before this feature can be +# exercised end-to-end, a maintainer must register a HuggingFace OAuth app +# under a Comfy-Org-controlled HF account and substitute its client_id here. +# Detailed walkthrough is in docs/server-side-model-downloads-handover.html +# ("HuggingFace OAuth app setup" section). Short version: +# 1. huggingface.co → Settings → Connected Apps → "Create app" +# 2. Default Scopes: check ``openid`` + ``profile`` (User Info) and +# ``gated-repos`` (Repository Access). Leave everything else off. +# 3. Redirect URLs: exactly ``http://127.0.0.1:41954/api/auth/huggingface/callback`` +# — must match ``REDIRECT_URI`` below; change both in lockstep if you +# change ``CALLBACK_PORT``. +# 4. Save → copy the resulting Client ID into ``HF_CLIENT_ID`` below. +# The client_id is not a secret (it travels through the user's browser in +# plaintext); HF's "Public app" type means there's no client secret to +# manage — PKCE replaces it. +HF_CLIENT_ID = "REPLACE_ME_WITH_COMFY_ORG_HF_OAUTH_CLIENT_ID" + +CALLBACK_HOST = "127.0.0.1" +CALLBACK_PORT = 41954 +CALLBACK_PATH = "/api/auth/huggingface/callback" +REDIRECT_URI = f"http://{CALLBACK_HOST}:{CALLBACK_PORT}{CALLBACK_PATH}" + +AUTHORIZE_URL = "https://huggingface.co/oauth/authorize" +TOKEN_URL = "https://huggingface.co/oauth/token" +# Minimal scope set for the feature: +# - openid : required by HF when the app uses OIDC at all +# - profile : lets ``HfApi.whoami(token=...)`` return a username for the +# settings UI; cosmetic but expected +# - gated-repos : grants the token enough to call ``auth_check`` and +# download files from public gated repos the user has +# accepted the license for. The wider ``read-repos`` scope +# would also work (it includes ``gated-repos``) but it +# additionally grants private-repo read access, which we +# don't need and which makes the consent screen scarier +# for the user. +SCOPE = "openid profile gated-repos" + +# Maximum time the callback server stays up waiting for the user to +# complete consent on huggingface.co. Past this, the port closes and +# the user has to click "Log in" again. +CALLBACK_TIMEOUT_SECS = 300 + + +# Process-wide lock so two simultaneous /api/hf-auth-login-start +# requests don't fight over port CALLBACK_PORT. +_OAUTH_LOCK = threading.Lock() + + +class OAuthInProgressError(Exception): + """Another OAuth attempt is already running.""" + + +class OAuthCallbackError(Exception): + """The OAuth callback returned an error (HF denied, port stolen, etc.).""" + + +# --- PKCE primitives ------------------------------------------------------ # + + +def _make_pkce() -> tuple[str, str, str]: + """Return ``(verifier, challenge, state)``. + + Verifier never leaves this process. Challenge and state travel + through the user's browser. State is checked on the callback to + prevent a malicious cross-origin redirect from injecting a token. + """ + verifier = secrets.token_urlsafe(64) + challenge = ( + base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("ascii")).digest()) + .rstrip(b"=") + .decode("ascii") + ) + state = secrets.token_urlsafe(32) + return verifier, challenge, state + + +def _build_authorize_url(challenge: str, state: str) -> str: + from urllib.parse import urlencode + + params = { + "client_id": HF_CLIENT_ID, + "redirect_uri": REDIRECT_URI, + "response_type": "code", + "scope": SCOPE, + "state": state, + "code_challenge": challenge, + "code_challenge_method": "S256", + } + return f"{AUTHORIZE_URL}?{urlencode(params)}" + + +# --- Token exchange ------------------------------------------------------- # + + +async def _exchange_code(code: str, verifier: str) -> Token: + """Trade the authorization code for an access+refresh token pair.""" + data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": REDIRECT_URI, + "client_id": HF_CLIENT_ID, + "code_verifier": verifier, + } + timeout = aiohttp.ClientTimeout(total=30) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(TOKEN_URL, data=data, ssl=ssl_context()) as resp: + resp.raise_for_status() + body = await resp.json() + return Token( + access_token=body["access_token"], + refresh_token=body.get("refresh_token"), + expires_at=time.time() + float(body.get("expires_in", 3600)), + scope=body.get("scope", SCOPE), + ) + + +async def refresh_access_token(refresh_token: str) -> Token: + """Trade a refresh_token for a new access (+ possibly refresh) token.""" + data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": HF_CLIENT_ID, + } + timeout = aiohttp.ClientTimeout(total=30) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(TOKEN_URL, data=data, ssl=ssl_context()) as resp: + resp.raise_for_status() + body = await resp.json() + return Token( + access_token=body["access_token"], + # If HF doesn't rotate refresh tokens, keep using the existing one. + refresh_token=body.get("refresh_token", refresh_token), + expires_at=time.time() + float(body.get("expires_in", 3600)), + scope=body.get("scope", SCOPE), + ) + + +# --- Callback server ------------------------------------------------------ # + + +async def start_login_flow() -> str: + """Begin one OAuth attempt: spawn the callback server, return the URL. + + Returns the URL the frontend should open in a new tab. Raises + ``OAuthInProgressError`` if another attempt is already running. + The callback server runs in the background until the user + completes consent or until ``CALLBACK_TIMEOUT_SECS`` elapses; + either way the lock + port are released afterward. + """ + if not _OAUTH_LOCK.acquire(blocking=False): + raise OAuthInProgressError() + + verifier, challenge, state = _make_pkce() + authorize_url = _build_authorize_url(challenge, state) + + # Fire the callback server on the running loop and return. + asyncio.create_task(_run_callback_server(verifier, state)) + return authorize_url + + +async def _run_callback_server(verifier: str, expected_state: str) -> None: + """Listen for HF's redirect once, capture the token, then shut down.""" + received: asyncio.Future[Token] = asyncio.get_event_loop().create_future() + + async def handler(request: web.Request) -> web.Response: + try: + if request.query.get("state") != expected_state: + return web.Response(status=400, text="state mismatch") + err = request.query.get("error") + if err: + received.set_exception(OAuthCallbackError(f"HF returned: {err}")) + return web.Response(status=400, text=f"OAuth error: {err}") + code = request.query.get("code") + if not code: + return web.Response(status=400, text="missing code") + tok = await _exchange_code(code, verifier) + if not received.done(): + received.set_result(tok) + return web.Response( + content_type="text/html", + text=( + "" + "

HuggingFace login successful

" + "

You can close this tab and return to ComfyUI.

" + "" + ), + ) + except Exception as exc: + if not received.done(): + received.set_exception(exc) + return web.Response(status=500, text=str(exc)) + + app = web.Application() + app.router.add_get(CALLBACK_PATH, handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, CALLBACK_HOST, CALLBACK_PORT, reuse_address=True) + try: + await site.start() + except OSError as e: + # Port already in use (or some other socket-bind failure). Release + # the lock so a future attempt has a chance to succeed. + logging.warning("[hf_auth] could not bind callback port: %s", e) + _OAUTH_LOCK.release() + return + + try: + token = await asyncio.wait_for(received, timeout=CALLBACK_TIMEOUT_SECS) + except asyncio.TimeoutError: + logging.info("[hf_auth] OAuth login timed out after %ds", CALLBACK_TIMEOUT_SECS) + return + except OAuthCallbackError as e: + logging.warning("[hf_auth] OAuth callback error: %s", e) + return + except Exception as e: + logging.warning("[hf_auth] unexpected OAuth failure: %s", e) + return + else: + HF_AUTH_STORE.set_token(token) + logging.info("[hf_auth] OAuth login complete") + finally: + await runner.cleanup() + if _OAUTH_LOCK.locked(): + _OAUTH_LOCK.release() + + +def is_login_in_progress() -> bool: + """True iff a callback server is currently bound + waiting.""" + return _OAUTH_LOCK.locked() + + +# Re-export for callers that only want the URL builder (e.g. tests). +__all__ = [ + "start_login_flow", + "refresh_access_token", + "is_login_in_progress", + "OAuthInProgressError", + "CALLBACK_TIMEOUT_SECS", +] diff --git a/app/model_downloader/hf_auth/token_store.py b/app/model_downloader/hf_auth/token_store.py new file mode 100644 index 000000000..28e0288f5 --- /dev/null +++ b/app/model_downloader/hf_auth/token_store.py @@ -0,0 +1,89 @@ +"""On-disk persistence for the HuggingFace OAuth token. + +The token shape mirrors what HF returns on the token exchange: an +``access_token``, an optional ``refresh_token``, the absolute epoch at +which the access token expires, and the granted scope. We persist +this so logging in once survives ComfyUI restarts; the file is mode +``0600`` so only the OS user can read it. +""" + +from __future__ import annotations + +import json +import logging +import os +import stat +import time +from dataclasses import asdict, dataclass +from typing import Optional + +import folder_paths + + +# Treat a token as expired this many seconds before its server-reported +# ``expires_at`` so we don't try to use a token mid-request only for it +# to flip stale between auth_check and the actual GET. +EXPIRY_BUFFER_SECS = 60 + +TOKEN_FILENAME = "hf_auth_token.json" + + +@dataclass +class Token: + """One OAuth token + the metadata we need to use it.""" + access_token: str + refresh_token: Optional[str] + expires_at: float # absolute epoch seconds + scope: str = "" + + def is_valid(self) -> bool: + """True iff we can use this token right now.""" + return ( + bool(self.access_token) + and (self.expires_at - time.time() > EXPIRY_BUFFER_SECS) + ) + + +def _token_path() -> str: + base = folder_paths.get_user_directory() + return os.path.join(base, TOKEN_FILENAME) + + +def load_token() -> Optional[Token]: + """Read the persisted token, returning ``None`` if absent or corrupt.""" + path = _token_path() + if not os.path.exists(path): + return None + try: + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + return Token(**data) + except (OSError, json.JSONDecodeError, TypeError) as e: + logging.warning("[hf_auth] could not load token at %s: %s", path, e) + return None + + +def save_token(token: Token) -> None: + """Atomically write the token with 0600 permissions.""" + path = _token_path() + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp = path + ".tmp" + with open(tmp, "w", encoding="utf-8") as f: + json.dump(asdict(token), f) + os.replace(tmp, path) + try: + os.chmod(path, stat.S_IRUSR | stat.S_IWUSR) + except OSError as e: + # On Windows / weird filesystems chmod may be a no-op; not fatal. + logging.debug("[hf_auth] chmod 0600 on %s failed: %s", path, e) + + +def delete_token() -> None: + """Remove the persisted token; no-op if it doesn't exist.""" + path = _token_path() + try: + os.remove(path) + except FileNotFoundError: + pass + except OSError as e: + logging.warning("[hf_auth] could not remove token at %s: %s", path, e) diff --git a/app/model_downloader/hf_url.py b/app/model_downloader/hf_url.py new file mode 100644 index 000000000..c305d5ac0 --- /dev/null +++ b/app/model_downloader/hf_url.py @@ -0,0 +1,41 @@ +"""Parsers for the ``huggingface.co`` URL shape we accept in workflows. + +The download API accepts URLs of the form +``https://huggingface.co///resolve//``. +We need to recover ``/`` (the *repo_id*) from such URLs for +``huggingface_hub`` API calls (notably ``HfApi.auth_check``). +""" + +from __future__ import annotations + +from typing import Optional +from urllib.parse import urlparse + +_HF_HOST = "huggingface.co" + + +def is_hf_url(url: str) -> bool: + """Cheap host check — does this URL point at huggingface.co?""" + try: + return urlparse(url).hostname == _HF_HOST + except ValueError: + return False + + +def repo_id_from_url(url: str) -> Optional[str]: + """Extract ``/`` from an HF model file URL. + + Returns ``None`` if the URL isn't on huggingface.co or doesn't look + like a model-file URL. The expected shape is + ``///resolve//`` — anything else + (datasets, spaces, /tree/, /blob/, …) we treat as out of scope here. + """ + if not is_hf_url(url): + return None + parts = urlparse(url).path.lstrip("/").split("/") + if len(parts) < 4 or parts[2] != "resolve": + return None + org, repo = parts[0], parts[1] + if not org or not repo: + return None + return f"{org}/{repo}" diff --git a/app/model_downloader/http_client.py b/app/model_downloader/http_client.py new file mode 100644 index 000000000..4c2b81dc6 --- /dev/null +++ b/app/model_downloader/http_client.py @@ -0,0 +1,63 @@ +"""Lazy module-level aiohttp ClientSession. + +A single shared session means TLS handshakes are reused across HEAD probes +and the subsequent GETs to the same host (HuggingFace is the dominant +case), which is a noticeable speedup on cold connections. + +We deliberately don't close the session at process exit — aiohttp's +warning about unclosed sessions is benign at shutdown, and adding atexit +plumbing buys nothing because the OS reclaims the sockets anyway. The +session lifetime is the lifetime of the Python process. +""" + +from __future__ import annotations + +import asyncio +import ssl +from typing import Optional + +import aiohttp +import certifi + + +# Larger per-host pool than aiohttp's default (=100 total / =0 per host) +# so concurrent gated probes + a download to the same host don't queue. +_CONNECTOR_LIMIT_PER_HOST = 8 + +_session: Optional[aiohttp.ClientSession] = None +_lock = asyncio.Lock() + + +def ssl_context() -> ssl.SSLContext: + """TLS context pinned to certifi's CA bundle. + aiohttp's default context uses the OS trust store, which isn't wired up + on some Python installs (python.org macOS, slim containers) — there TLS + to huggingface.co fails with CERTIFICATE_VERIFY_FAILED. + """ + return ssl.create_default_context(cafile=certifi.where()) + + +async def get_session() -> aiohttp.ClientSession: + """Return the shared session, creating it on first call.""" + global _session + if _session is not None and not _session.closed: + return _session + async with _lock: + if _session is None or _session.closed: + connector = aiohttp.TCPConnector( + limit_per_host=_CONNECTOR_LIMIT_PER_HOST, + ssl=ssl_context(), + ) + _session = aiohttp.ClientSession(connector=connector) + return _session + + +def parse_content_length(value: Optional[str]) -> Optional[int]: + """Parse a byte-count header value, or None if absent/malformed/negative.""" + if not value: + return None + try: + n = int(value) + except ValueError: + return None + return n if n >= 0 else None diff --git a/app/model_downloader/paths.py b/app/model_downloader/paths.py new file mode 100644 index 000000000..142ada8db --- /dev/null +++ b/app/model_downloader/paths.py @@ -0,0 +1,93 @@ +"""Path resolution for model downloads. + +Model identifiers used across the download API are *relative destination +paths* of the form ``/`` (e.g. ``loras/my_lora.safetensors``). +This module turns one of those identifiers into an absolute on-disk path +under one of ComfyUI's registered model folders, while rejecting unknown +folders, path traversal, and other ill-formed inputs. +""" + +import os +import re +from typing import Optional, Tuple + +import folder_paths + + +# Constrain components so a model_id can never escape its target directory. +# - directory: a single path segment of safe chars +# - filename: a single path segment of safe chars, must end with a model extension +_SEGMENT_RE = re.compile(r"^[A-Za-z0-9._-]+$") + + +class InvalidModelId(ValueError): + """Raised when a model_id is syntactically invalid or refers to an + unknown model folder.""" + + +def parse_model_id(model_id: str) -> Tuple[str, str]: + """Split ``/`` and validate both components. + + Returns ``(directory, filename)``. Raises ``InvalidModelId`` on + malformed input. Does NOT touch the filesystem. + """ + if not isinstance(model_id, str) or "/" not in model_id: + raise InvalidModelId(f"model_id must be '/', got {model_id!r}") + directory, _, filename = model_id.partition("/") + if "/" in filename or not directory or not filename: + raise InvalidModelId(f"model_id must be exactly one '/' separator, got {model_id!r}") + if not _SEGMENT_RE.match(directory): + raise InvalidModelId(f"invalid directory segment {directory!r}") + if not _SEGMENT_RE.match(filename): + raise InvalidModelId(f"invalid filename segment {filename!r}") + if directory not in folder_paths.folder_names_and_paths: + raise InvalidModelId(f"unknown model folder {directory!r}") + return directory, filename + + +def resolve_existing(model_id: str) -> Optional[str]: + """Return the absolute path of an installed model, or None if missing. + + Honours ``extra_model_paths.yaml`` transparently via + ``folder_paths.get_full_path``. + """ + directory, filename = parse_model_id(model_id) + return folder_paths.get_full_path(directory, filename) + + +def resolve_destination(model_id: str) -> Tuple[str, str]: + """Return ``(final_path, tmp_path)`` for a download. + + Downloads land at the first registered path for the model's directory + (the "primary" location). The ``.tmp`` sibling is used as the write + target and atomically renamed on success. + """ + directory, filename = parse_model_id(model_id) + roots = folder_paths.get_folder_paths(directory) + if not roots: + raise InvalidModelId(f"no on-disk path registered for folder {directory!r}") + root = roots[0] + final_path = os.path.join(root, filename) + tmp_path = final_path + ".tmp" + return final_path, tmp_path + + +def iter_all_tmp_paths(): + """Yield every ``*.tmp`` file under every registered model folder. + + Used at startup to sweep orphans left by crashed/restarted downloads. + """ + seen_roots: set[str] = set() + for directory in folder_paths.folder_names_and_paths.keys(): + for root in folder_paths.get_folder_paths(directory): + if root in seen_roots or not os.path.isdir(root): + continue + seen_roots.add(root) + try: + for entry in os.scandir(root): + if entry.is_file() and entry.name.endswith(".tmp"): + yield entry.path + except OSError: + # Folder might be unreadable / missing on certain mounts — + # not fatal, just skip it. + continue diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index 0f30608a9..92d013386 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -98,12 +98,24 @@ def _parse_cli_feature_flags() -> dict[str, Any]: # Default server capabilities +def _hf_auth_eligible_at_startup() -> bool: + """Snapshot eligibility once at feature-flag init time. + + Imports lazily because the flags module loads very early in the + server boot sequence — earlier than the model_downloader package. + """ + from app.model_downloader.hf_auth.eligibility import is_hf_auth_eligible + return is_hf_auth_eligible() + + _CORE_FEATURE_FLAGS: dict[str, Any] = { "supports_preview_metadata": True, "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "extension": {"manager": {"supports_v4": True}}, "node_replacements": True, "assets": args.enable_assets, + "server_side_model_downloads": True, + "hf_auth_eligible": _hf_auth_eligible_at_startup(), } # CLI-provided flags cannot overwrite core flags diff --git a/docs/server-side-model-downloads-handover.html b/docs/server-side-model-downloads-handover.html new file mode 100644 index 000000000..663d6076a --- /dev/null +++ b/docs/server-side-model-downloads-handover.html @@ -0,0 +1,1273 @@ + + + + +Server-Side Model Downloads — Handover + + + + +
+

Server-Side Model Downloads

+
Handover document · branch feat/server-side-model-downloads
+
+ +
+

⚠ Action required before this feature can run end-to-end

+

+ HF_CLIENT_ID in app/model_downloader/hf_auth/oauth.py is a + placeholder string ("REPLACE_ME_WITH_COMFY_ORG_HF_OAUTH_CLIENT_ID"). + HuggingFace will reject the authorize redirect until a real app is registered under a + Comfy-Org-controlled HF account and the constant is replaced. +

+

+ Detailed walkthrough is in + §11 — HuggingFace OAuth app setup at the bottom of this doc; + it lists each field and which boxes to tick. Until the placeholder is replaced, the + backend is otherwise fully functional (state polling, public downloads, gated detection + all work) — only the login flow itself fails. +

+
+ + + + + +

1. Overview & scope

+ +

+ ComfyUI workflows declare model dependencies inline via properties.models + entries on loader nodes — each one carries a filename, a directory (e.g. loras, + checkpoints), and a URL to fetch the file from. Until this feature, when a + workflow loaded with a missing model, the frontend offered the user a download button that + triggered a plain browser download via a synthesized <a download> click. + Files landed in the user's Downloads folder; users then had to manually move them + into ComfyUI/models/<directory>/. Gated HuggingFace models couldn't be + downloaded at all without manual huggingface-cli login + hf_hub_download + out-of-band. +

+ +

This change moves the fetch to the server, lands files in the correct on-disk location, and adds +authenticated HuggingFace support so gated models can be downloaded after a one-click OAuth flow.

+ +
+
Scope
+
    +
  • Server-side downloads with progress + cancellation, atomic file placement.
  • +
  • Gated-model detection (HF auth_check) with appropriate UI states.
  • +
  • HuggingFace OAuth PKCE flow with persisted token; per-process single-token model.
  • +
  • Single-tenant local trust model only — the feature gates itself off + on multi-user or non-loopback deployments because there's no real authentication layer + in core ComfyUI to map users to their own tokens.
  • +
+
+ +
+
Out of scope
+ Per-user HF tokens, real authentication, multi-tenant isolation. These would require + building a user-identity layer in core ComfyUI (sessions, cookies, login). The feature + deliberately disables itself rather than ship a half-measure. +
+ + + +

2. Architecture at a glance

+ +
+ ┌──────────────────────────────────────────────────────┐ + │ ComfyUI_frontend (Vue 3 + Pinia + TypeScript) │ + │ - MissingModelCardServerSide.vue │ + │ - HfAuthSettingsPanel.vue │ + │ - useServerSideDownloadsStore (Pinia) │ + │ - serverDownloadsApi.ts (API client) │ + └──────────────────────────┬───────────────────────────┘ + │ HTTP JSON, kebab-case + │ 1 Hz poll when card visible + ▼ + ┌──────────────────────────────────────────────────────┐ + │ ComfyUI (Python aiohttp) │ + │ app/model_downloader/ │ + │ ├─ api/routes.py ◄── 6 endpoints │ + │ ├─ download_server.py ◄── singleton registry │ + │ ├─ downloader.py ◄── streaming worker │ + │ ├─ gated_detection.py ◄── probe + caches │ + │ ├─ allowlist.py ◄── SSRF allowlist │ + │ ├─ paths.py ◄── model_id ↔ disk │ + │ └─ hf_auth/ │ + │ ├─ oauth.py ◄── PKCE + callback srv │ + │ ├─ auth_store.py ◄── token singleton │ + │ ├─ token_store.py ◄── disk I/O │ + │ └─ eligibility.py ◄── loopback gate │ + └──────────────────────────────────────────────────────┘ + │ + ▼ + ┌─────────────────────┐ + │ Filesystem │ + │ models/<dir>/... │ + │ user/hf_auth_token │ + └─────────────────────┘ +
+ +

Key idea: every concern lives in app/model_downloader/ as a self-contained +subsystem. Wiring into the rest of ComfyUI is two lines in server.py +(register_routes(self.app)) and one feature-flag entry in +comfy_api/feature_flags.py.

+ + + +

3. Download mechanism

+ +

3.1 The singleton

+ +

+ DOWNLOAD_SERVER (in app/model_downloader/download_server.py) is the + process-wide registry of in-flight downloads. It exists so that: +

+
    +
  • Multiple polling tabs see a single coherent view of "what's downloading right now."
  • +
  • At most one download per model_id can run at a time — preventing two + concurrent writers to the same destination path.
  • +
  • Cancellation is just removal from the registry; the worker discovers cancellation + on its next chunk-boundary check.
  • +
+ +

3.2 DownloadSession state

+ +
@dataclass
+class DownloadSession:
+    model_id: str             # e.g. "loras/my_lora.safetensors"
+    url: str                  # the URL we're fetching from
+    progress: Optional[float] # fraction in [0,1]; None until total known
+    bytes_downloaded: int
+    total_bytes: Optional[int]
+    epoch: int                # see "atomicity" below
+ +

+ The registry is a plain dict[str, DownloadSession] guarded by a + threading.Lock (callable from both the asyncio event-loop thread and + the download-worker tasks). +

+ +

3.3 Download lifecycle

+ +
+ POST /api/download-models + │ + ▼ + ┌─── precondition gate (atomic) ────────────────────────┐ + │ • parse_model_id → valid path, known dir │ + │ • is_url_allowed → HF / Civitai / localhost │ + │ • resolve_existing → not already on disk │ + │ • DOWNLOAD_SERVER.is_downloading? → not in flight │ + │ • probe_url → not gated-without-access │ + │ Any failure → 400/409, NOTHING is registered. │ + └────────────────────────────────────────────────────────┘ + │ + ▼ + try_register(model_id, url) ◄── new epoch number assigned + │ + ▼ + schedule_batch(sessions) ◄── async task started, route returns 202 + │ + ▼ + stream_to_disk(session): + • GET url with Authorization (if HF + token stored) + • aiohttp .content.iter_chunked(64 KiB) + • write to <final_path>.tmp + • between each chunk: + if not DOWNLOAD_SERVER.is_active(session): + raise DownloadCancelled + • update_progress() after each chunk + │ + ▼ + os.replace(tmp_path, final_path) ◄── atomic rename + │ + ▼ + DOWNLOAD_SERVER.finish(session) +
+ +

3.4 Atomicity

+ +

Three independent atomicity guarantees, each addressing a different race:

+
    +
  1. File atomicity: downloads write to <final>.tmp and + use os.replace for the promotion. A crashed/cancelled download leaves only + the .tmp, never a partial-but-named-correct file that a loader would happily + load and silently produce garbage outputs.
  2. + +
  3. Registry atomicity: try_register holds the lock and inserts + iff no entry exists. A second concurrent request for the same model_id returns + None; the route then 409s and rolls back any sessions it had already registered + in this batch.
  4. + +
  5. Epoch-based cancellation: each DownloadSession carries an + epoch counter assigned on registration. If a user cancels and then + immediately re-triggers the download (same model_id), a new session + with a new epoch is registered. The old worker, still running on the cancelled session, + observes is_active(session) as False (epoch mismatch), rolls back its own + .tmp, and exits without affecting the new session. Prevents the old worker's + late finish() from accidentally evicting the new session.
  6. +
+ +

3.5 Orphan cleanup

+ +

+ When the server restarts mid-download, any .tmp file is by definition orphaned. + DOWNLOAD_SERVER.sweep_orphan_tmp_files() walks every registered model folder + and removes *.tmp files. Idempotent; runs on the first download request rather + than module import to keep the import path I/O-free. +

+ +

3.6 Auth headers (HuggingFace)

+ +

+ When session.url is on huggingface.co and a token is stored, + stream_to_disk attaches Authorization: Bearer <access_token> + to the GET. Non-HF URLs receive no auth header (avoids token leakage to other hosts). + This is HF's documented way to access gated repos with a personal access token — no + reliance on huggingface_hub's download API. +

+ + + +

4. HuggingFace OAuth mechanism

+ +

4.1 Why OAuth at all

+ +

+ Some HF repos are gated — the user has to accept a license / be approved before they can + download. The bearer token from a logged-in HF account passes that gate. Rather than + asking the user to paste a personal access token (security-awful UX), we run a proper + OAuth 2.0 Authorization Code flow with PKCE, identical pattern to what the + huggingface-cli login command does internally. +

+ +

4.2 Where the token lives

+ + + + + + + + + + + + + +
LayerStorageNotes
In memoryHF_AUTH_STORE singleton (auth_store.py)Lazily loaded from disk on first access. Mutations also flushed to disk.
On disk<user_dir>/hf_auth_token.jsonAtomic write via .tmp + os.replace, chmod 0600 + so only the OS user can read it.
+ +

Token shape (mirrors what HF returns from the token endpoint):

+
{
+  "access_token":  "hf_oauth_…",
+  "refresh_token": "…",         // null if not granted
+  "expires_at":    1739895432.0, // absolute epoch seconds
+  "scope":         "openid profile read-repos"
+}
+ +

4.3 Token lifecycle

+ +
+ POST /api/hf-auth-login-start (only when eligible, see §6) + │ + ├─► generate PKCE verifier + challenge + state + ├─► spin up callback server at 127.0.0.1:41954 (port-locked, 5min timeout) + ├─► return { authorize_url } to frontend + │ + ▼ + Frontend: window.open(authorize_url, "_blank") + │ + ▼ + User authorizes on huggingface.co + │ + ▼ + HF redirects to http://127.0.0.1:41954/api/auth/huggingface/callback?code=…&state=… + │ + ├─► validate state == expected_state (CSRF defence) + ├─► exchange code + verifier → POST https://huggingface.co/oauth/token + ├─► HF_AUTH_STORE.set_token(token) (in memory + disk) + ├─► render "Login complete" page in user's tab + └─► tear down callback server, release port + + Frontend polls /api/hf-auth-token-status next tick and sees token_available: true. + + On expiry (during any request that needs the token): + ├─► get_valid_token() detects expires_at < now + 60s + ├─► POST refresh_token to HF token endpoint + ├─► HF_AUTH_STORE.set_token(refreshed) + └─► return refreshed token to caller + + POST /api/hf-auth-logout + ├─► HF_AUTH_STORE.clear() — wipe memory + remove disk file + └─► (does NOT revoke the token on HF's side; user can do that + at huggingface.co/settings/tokens) +
+ +
+
Single token, single process
+ Only one token can be stored at a time. Calling login-start while + already logged in (or with a pending login flow) will either lock-conflict (409) or + overwrite the existing token on success. This is intentional given the + single-tenant scope — see §6. +
+ +

4.4 PKCE + state protection

+ +

+ Standard OAuth 2.0 PKCE (RFC 7636) with the SHA-256 method: +

+
    +
  • Random 64-byte URL-safe verifier never leaves the server process.
  • +
  • SHA-256 hash of the verifier sent as the code_challenge in the + authorize URL.
  • +
  • Random 32-byte state validated on callback; mismatches return 400 and + the token exchange is skipped (CSRF defence).
  • +
+ + + +

5. API reference & frontend usage

+ +

All routes live under /api/, use kebab-case paths, and POST for input-bearing +operations even when they're "read-only" — keeps semantics uniform and avoids URL-length +limits when payloads grow.

+ + +
+
+ POST + /api/models-availability-status + 1 Hz poll +
+
+

Purpose

+

One-stop status endpoint. Returns per-model state (available / missing / downloading) plus + metadata (file size, HF downloadability) plus current HF auth snapshot, all in one shot.

+ +

Request

+
{
+  "models": {
+    "loras/foo.safetensors": "https://huggingface.co/org/repo/resolve/main/foo.safetensors",
+    "checkpoints/bar.safetensors": "https://huggingface.co/.../bar.safetensors"
+  }
+}
+ +

Response

+
{
+  "models": {
+    "loras/foo.safetensors": {
+      "state": "downloading",                     // "available" | "missing" | "downloading"
+      "progress": {
+        "bytes_downloaded": 1024000,
+        "total_bytes":      29145431166,
+        "progress":         0.000035                 // null until total known
+      },
+      "file_size":          29145431166,             // bytes; null if not probed
+      "is_hf_downloadable": true                     // null for non-HF / probe failure
+    },
+    "checkpoints/bar.safetensors": {
+      "state": "missing",
+      "progress": null,
+      "file_size": 1234567890,
+      "is_hf_downloadable": false                    // gated, no access
+    }
+  },
+  "hf_auth": {
+    "token_available": true,
+    "eligible":         true
+  }
+}
+ +

Frontend usage

+

+ Called every 1 second by useServerSideDownloadsStore.refresh() + while the missing-models card is mounted. Timer auto-stops when no row is downloading + and every remaining missing row is gated (no further state changes possible without + a user action). +

+

+ The polling timer re-arms on user actions: clicking Download, clicking HF login, + or a workflow change that re-registers the model list. +

+
+
+ + +
+
+ POST + /api/download-models + 202 on accept +
+
+

Purpose

+

Trigger one or more downloads. Atomic: either every model passes + every precondition (valid id, allowed URL, not on disk, not in flight, not gated-to-us) + and all are scheduled, or none are — the request returns an error and the registry is + left unchanged.

+ +

Request

+
{
+  "models": {
+    "loras/foo.safetensors": "https://huggingface.co/.../foo.safetensors"
+  }
+}
+ +

Response (success)

+
HTTP 202 Accepted
+{
+  "accepted":  true,
+  "scheduled": ["loras/foo.safetensors"]
+}
+ +

Response (error)

+
HTTP 400 / 409
+{
+  "error": {
+    "code":    "MODEL_NOT_DOWNLOADABLE",  // INVALID_MODEL_ID / URL_NOT_ALLOWED /
+                                          // ALREADY_AVAILABLE / ALREADY_DOWNLOADING /
+                                          // MODEL_NOT_DOWNLOADABLE / EMPTY_REQUEST
+    "message": "…human-readable…",
+    "details": { "model_id": "loras/foo.safetensors", "url": "https://…" }
+  }
+}
+ +

Frontend usage

+

Triggered by clicking Download on a row or Download All Available in + the card header. On 202, the store immediately calls refresh() so the + progress bar appears in the same render tick; the regular 1 Hz polling takes over from there.

+
+
+ + +
+
+ POST + /api/cancel-model-download-session +
+
+

Request

+
{ "model_id": "loras/foo.safetensors" }
+

Response

+
{ "cancelled": true }    // or HTTP 404 with NOT_DOWNLOADING if no active session
+

Frontend usage

+

The X button on a downloading row. The store re-polls availability immediately + so the UI flips back to "missing" without waiting for the next tick.

+

Cancellation is cooperative — the worker checks is_active between chunks + (typically <1s latency) and rolls back its own .tmp on the way out.

+
+
+ + +
+
+ GET + /api/hf-auth-token-status +
+
+

Response

+
{
+  "token_available": true,
+  "username":        "ogluzman"   // resolved via HfApi.whoami(); null if token invalid
+}
+

Frontend usage

+

Used by the HuggingFace settings panel on open and after any login/logout + action. The general polling path doesn't need this endpoint — the same boolean is + embedded in /api/models-availability-status under hf_auth.token_available. + Kept separate so the settings panel doesn't have to query the unrelated models endpoint.

+
+
+ + +
+
+ POST + /api/hf-auth-login-start +
+
+

Request

+

Empty body.

+

Response (success)

+
{
+  "authorize_url": "https://huggingface.co/oauth/authorize?client_id=…&state=…&code_challenge=…"
+}
+

Error responses

+
    +
  • 403 HF_AUTH_NOT_ELIGIBLE — deployment fails the loopback / multi-user gate. See §6.
  • +
  • 409 HF_AUTH_IN_PROGRESS — another login attempt holds the callback port.
  • +
+

Side effect

+

Spins up the OAuth callback server on 127.0.0.1:41954 for up to 5 minutes. + See §4 for the full lifecycle.

+

Frontend usage

+

Triggered from the login banner in the missing-models card, or the Log in with HuggingFace + button in the Settings → HuggingFace panel. On 200, the frontend opens the + authorize_url in a new tab via window.open(url, "_blank").

+
+
+ + +
+
+ POST + /api/hf-auth-logout +
+
+

Response

+
{ "logged_out": true }
+

Frontend usage

+

Settings → HuggingFace → Log out button. Idempotent — succeeds even if no + token was held. Note this does not revoke the token on HF's side; the user can do that + at huggingface.co/settings/tokens if they want full revocation.

+
+
+ + + +

6. Loopback eligibility gate

+ +

6.1 The rule

+
# app/model_downloader/hf_auth/eligibility.py
+
+def is_hf_auth_eligible() -> bool:
+    return _is_loopback(args.listen) and not args.multi_user
+ +

HF auth surfaces — both the login flow and the settings panel — appear iff this returns True.

+ +

6.2 Why it exists

+ +

+ Core ComfyUI has no authentication. Any HF token the server holds is implicitly shared + by anyone who can reach the server. In a single-user local install that's fine — the OS + user is the boundary, the loopback bind keeps remote actors out. In any other deployment + it would be a credential-leak by misconfiguration: +

+ +
    +
  • Non-loopback: anyone on the network who can reach the port could trigger + downloads using the operator's HF account.
  • +
  • --multi-user mode: multiple declared users (via the + unauthenticated comfy-user header) would all share one HF token implicitly — + Alice's prompts would silently fetch gated content as Bob.
  • +
+ +

+ Both cases are real credential leakage that the operator probably didn't realize + they were enabling. The gate disables the feature instead of shipping a footgun. +

+ +

6.3 What's gated

+ + + + + + + + + + + + + + + + + + + + + + + +
SurfaceHow the gate is applied
Server feature flag hf_auth_eligibleComputed once at startup, returned by /api/features. Frontend reads it + on init to decide whether to render any HF UI at all.
Login start endpointReturns 403 HF_AUTH_NOT_ELIGIBLE if called when ineligible. Defence in + depth — even if the frontend bug rendered the button, the endpoint refuses.
Settings panel (HfAuthSettingsPanel.vue)Registered in useSettingUI.ts only when + api.serverFeatureFlags['hf_auth_eligible'] is true.
Card login bannerConditional render: only shown when eligible and there's at least one + gated row and no token yet.
Per-row gated UI textThree variants based on (eligible, logged-in) state — see §8.
+ +

6.4 Implementation note

+ +

+ We had to inline a copy of is_loopback in eligibility.py + (rather than importing from server.py) because + comfy_api/feature_flags.py evaluates its registry at module-import time — + earlier than server.py defines the helper. The inlined version is + ~20 lines, mirrors server.is_loopback exactly, and is the kind of thing + worth flagging if anyone ever does a "shared util" cleanup pass. +

+ + + +

7. Probe caching strategy

+ +

The polling endpoint runs probe_url(url) for every model on every tick. To +keep that cheap (HuggingFace round-trip per probe is >100ms), the probe layer caches what's +safe to cache and recomputes what isn't:

+ + + + + + + + + + + + + + + + + + + + + + + +
FieldCached?Why
is_gated (intrinsic — "is this repo gated on HF")✅ Forever, per URLProperty of the model, doesn't depend on the user. Determined by a single + auth_check(repo_id, token=None) on first probe.
file_size✅ Forever, per URL (but only after a successful probe)File size doesn't change. We only attempt the HEAD when is_hf_downloadable + is True — avoids caching None from a 401-because-gated, which would otherwise + survive a later successful login.
is_hf_downloadable❌ Recomputed every callDepends on the current token state. Has to update within one poll cycle after login / + logout / license acceptance. Recomputed via auth_check(repo_id, token=current_token) + — but skipped entirely for URLs known to be non-gated (those are trivially True).
On-disk file existence (state)❌ Per callos.path.isfile is a microsecond syscall; not worth caching, and we need + it fresh so the row flips to "available" the instant a download completes.
+ +

+ Single-flight protection: a per-URL asyncio.Lock dedupes concurrent probes + for the same URL — when many polls land in the same tick, exactly one of them runs the HF + call and the others await the same result. Failures aren't cached (they're transient by + nature; retry next call). +

+ +
+
Why this is enough
+ License acceptance happens out-of-band on huggingface.co. The user clicks our + "repository page" link, accepts the license, returns. The next 1 Hz poll's + auth_check with their token now succeeds → is_hf_downloadable + flips to true → the size HEAD fires on that same call → the row transitions from + gated UI to a Download button with the correct size, all within a second of returning. + No frontend cache invalidation, no focus hooks, no manual refresh. +
+ + + +

8. Frontend ↔ backend separation

+ + + + + + + + + + + + + + + + + + + + + + + +
BackendFrontend
Repocomfyanonymous/ComfyUI (this repo)Comfy-Org/ComfyUI_frontend (separate repo)
Language / stackPython 3.13, aiohttp, pydantic, pytestVue 3, TypeScript, Pinia, Vite, PrimeVue, Tailwind
Release artefactSource-distributed; users pip-install the packageBuilt bundle published as the comfyui-frontend-package pip package; ComfyUI + imports the static files.
This feature's filesapp/model_downloader/**, two-line edit to server.py, one-line + edit to comfy_api/feature_flags.py, additions to openapi.yaml, + two test files under tests-unit/app_test/src/platform/missingModel/serverDownloads/** (new directory), a few-line edit + to MissingModelCard.vue for the feature-flag switch, and a registration edit + in src/platform/settings/composables/useSettingUI.ts
+ +

8.1 Local dev workflow

+ +
# Backend  (one terminal)
+cd ComfyUI
+python main.py --listen 127.0.0.1 --port 8189 --cpu
+
+# Frontend (another terminal)
+cd ComfyUI_frontend
+DEV_SERVER_COMFYUI_URL=http://127.0.0.1:8189 pnpm dev
+# Vite serves at http://localhost:5173 and proxies /api/* to the backend
+ +

Open http://localhost:5173 in a browser — you get the Vite dev server with HMR, +talking to your local backend.

+ +

8.2 Frontend integration points

+ +
    +
  • + Feature-flag gate. MissingModelCard.vue renders the new + MissingModelCardServerSide.vue when isServerSideDownloadsAvailable() + returns true (the server_side_model_downloads server feature flag). Old servers + silently fall through to the legacy in-browser download path. +
  • +
  • + Eligibility-flag gate. The HF settings panel and login banner only render + when hf_auth_eligible is true. Read once at startup. +
  • +
  • + Single store of truth. useServerSideDownloadsStore (Pinia) + holds the entire view of the polling response. Components read; only the store mutates. +
  • +
  • + Three gated-row variants. The per-row gated message changes based on + (eligible, logged-in): +
      +
    • Not eligible: "open the repository page to accept the license, then place + the file in models/<dir>/ manually."
    • +
    • Eligible, not logged in: "log in with HuggingFace above to enable the download."
    • +
    • Eligible, logged in: "visit the repository page to accept the license, then + come back to download" (license acceptance does the rest via the 1 Hz poll).
    • +
    +
  • +
+ + + +

9. Tests & OpenAPI spec

+ +

9.1 Test coverage

+

~70 unit tests in two files under tests-unit/app_test/:

+
    +
  • model_downloader_test.py — allowlist, path validation, registry + lifecycle (including epoch race semantics), orphan .tmp cleanup, + precondition gating on all four model routes, atomic batch behavior.
  • +
  • hf_auth_test.py — token store (save / load / chmod / corruption / + refresh), eligibility under (listen, multi_user) matrix, URL parsing, + probe caching (intrinsic + size + skip-when-not-downloadable), all three HF auth + routes, PKCE primitives + authorize URL shape.
  • +
+
$ pytest tests-unit/app_test/model_downloader_test.py tests-unit/app_test/hf_auth_test.py -q
+71 passed in 0.23s
+ +

9.2 OpenAPI spec

+

+ All six routes are documented in openapi.yaml with request/response schemas. + The spec is hand-maintained — there's no codegen between handler signatures and the YAML. + §10 flags this as a long-term tech-debt item. +

+

+ Lint is enforced in CI via Spectral + (.github/workflows/openapi-lint.yml); local run: +

+
npx -y @stoplight/spectral-cli@6 lint openapi.yaml --ruleset .spectral.yaml --fail-severity=error
+ + + +

10. Open follow-ups & gotchas

+ +
+
Placeholder OAuth client_id
+ HF_CLIENT_ID in app/model_downloader/hf_auth/oauth.py is a + placeholder string and must be replaced with a real registered HuggingFace OAuth app's + client_id before the login flow can succeed. Full instructions are at the top of this + document (the yellow "Action required" callout). Until that's done, calling + POST /api/hf-auth-login-start succeeds locally but the resulting + authorize_url will return an error from huggingface.co. +
+ +
+
Org SSO requirement on HuggingFace
+ Some HF orgs (e.g. Lightricks) require SSO authorization of personal access tokens before + byte-level access is granted. The token-based flow we build returns + is_hf_downloadable: false for those repos with a clear log line: + [hf_auth] auth_check forbids …/… (HTTP 403) — treating as gated. + The user has to authorize their token via the org's SSO setup at + https://huggingface.co/organizations/<org>/sso. Not a code bug — a + property of the org's policy. +
+ +
+
No TLS in default ComfyUI
+ ComfyUI supports TLS via --tls-keyfile / --tls-certfile but + doesn't enable it by default. Browsers treat http://localhost as a + secure context, so Secure cookies / HF auth still work without TLS on + loopback. Non-loopback deployments without TLS are correctly excluded by the eligibility + gate, so the lack of default TLS isn't a hole for this feature. +
+ +

10.1 Things deliberately not done

+
    +
  • Per-user HF tokens. Requires a real auth layer in core ComfyUI + (sessions, login, identity). Out of scope; the loopback gate is the substitute.
  • +
  • Hash verification of downloaded files. Some + properties.models[*].hash entries carry a SHA. We don't verify; trust the + source. Easy to add if needed (one method on stream_to_disk).
  • +
  • Resumable downloads. A failed download starts from zero next time. + Range requests + offset tracking would add it.
  • +
  • Codegen for the OpenAPI spec. Spec is hand-maintained and lint-checked, + not derived from handler signatures. Long-term direction is probably pydantic-driven + schema export, but that's a project unto itself.
  • +
+ +

10.2 Convention summary

+
    +
  • Route paths: kebab-case for all new routes + (/api/models-availability-status, etc.). Older endpoints use snake_case; + newer assets endpoints use kebab; we picked kebab to match the newer direction.
  • +
  • Error envelope: + {"error": {"code": "MACHINE_READABLE", "message": "human", "details": {...}}}. + Matches the pattern in app/assets/api/routes.py.
  • +
  • Pydantic at the boundary: request schemas in + schemas_in.py, response schemas in schemas_out.py, + validated via Schema.model_validate(payload) in handlers.
  • +
  • Logging prefix: all logs use [model_downloader] or + [hf_auth] prefixes for grep-ability.
  • +
+ +

10.3 Useful greps

+
# Find every backend file touched by this feature
+ls app/model_downloader app/model_downloader/api app/model_downloader/hf_auth
+
+# Find every place is_loopback is consulted (3 callers)
+grep -rn "is_loopback" --include="*.py" app/ server.py
+
+# Confirm the HF OAuth callback port and redirect URI
+grep -n "CALLBACK_PORT\|REDIRECT_URI" app/model_downloader/hf_auth/oauth.py
+
+# Run the test suite for just this feature
+.venv/bin/python -m pytest tests-unit/app_test/model_downloader_test.py \
+                            tests-unit/app_test/hf_auth_test.py -q
+ + + +

11. HuggingFace OAuth app setup

+ +

+ Step-by-step walkthrough for creating the OAuth app whose client_id goes into + HF_CLIENT_ID. Reflects what the HuggingFace settings UI looked like at the + time this feature was developed; HF occasionally moves things around but the fields + themselves are stable. +

+ +

11.1 Navigation

+
    +
  1. Sign in at huggingface.co + with the Comfy-Org-controlled account that should own the app.
  2. +
  3. Open user settings (avatar menu → Settings).
  4. +
  5. In the left sidebar, click Connected Apps. (Not Access Tokens + — that's for personal access tokens, a different concept.)
  6. +
  7. Click Create app (or similar — the button label has varied).
  8. +
+ +

11.2 Fields to fill in

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FieldValueNotes
Application Namee.g. ComfyUIShown on the user's consent screen and in their Connected Apps list. Keep it + recognisable.
Homepage URLOptional. Leave blank or use https://www.comfy.org.Cosmetic.
LogoOptional.Cosmetic.
Token ExpirationDefault (8 hours) is fine.Our code transparently refreshes via the OAuth refresh-token flow; a shorter expiry + just means refresh happens more often. Don't pick an extremely short one — you'd put + needless load on HF's token endpoint.
Default ScopesSee §11.3 below.Critical — this controls what consent the user sees and what the token can do.
Redirect URLshttp://127.0.0.1:41954/api/auth/huggingface/callback + Must match exactly. If you change CALLBACK_PORT in + oauth.py, change this in lockstep. Multiple redirect URLs can be + registered (one per line) if you need both dev and prod variants later. +
+ +

11.3 Scopes — exactly which boxes to tick

+ +

+ HF groups scopes into sections. The bare minimum for this feature is three + checkboxes total. Leave everything else off. +

+ + + + + + + + + + + + + + + + + + +
SectionScope to checkWhy
User InfoopenidRequired by HF when the app uses OpenID Connect at all (which our PKCE + flow does — it's part of the OAuth2 + OIDC handshake).
User InfoprofileLets HfApi.whoami(token=...) return a username. The Settings panel + shows that username next to the "Logged in" indicator. Strictly cosmetic but + expected by the UI.
Repository Accessgated-repos
"Read public gated repos only"
The key scope. Grants the token enough to (a) call auth_check against + gated repos the user has accepted the license for, and (b) download files from those + repos. Public-only — no private-repo access included, no write permissions.
+ +
+
Do not pick a wider scope
+ read-repos would also work for the feature (it includes + gated-repos plus private-repo read access), but picking it makes the + user's consent screen on huggingface.co look scarier ("this app wants to read your + private repositories"). Users may bail. Stick to gated-repos. +
+ +

11.4 Public app + PKCE

+ +

+ After creation, HF will label the app a Public app and explicitly note: + "No client secret. Use PKCE or device code flow for authentication." This is + expected and correct — we use PKCE (see §4). Do not + click Add client secret; we don't need it and having one without using it would + be a future footgun. +

+ +

11.5 Wire the client_id into the code

+ +

The Credentials section of the new app shows a Client ID in the form of a UUID +(e.g. a8189e14-9246-4f19-bd6a-a307bdcb9276). Copy that value and paste it +verbatim into:

+ +
# app/model_downloader/hf_auth/oauth.py  (around line 49)
+HF_CLIENT_ID = "paste-the-uuid-here"
+ +

That's the only code change required. Restart ComfyUI; POST /api/hf-auth-login-start +should now produce an authorize_url that huggingface.co accepts.

+ +

11.6 Test the round-trip

+ +
    +
  1. Start ComfyUI on loopback: python main.py --listen 127.0.0.1 --port 8189
  2. +
  3. + Confirm eligibility: +
    curl -s http://127.0.0.1:8189/api/features | grep hf_auth_eligible
    +# expect: "hf_auth_eligible": true
    +
  4. +
  5. + Trigger the login flow: +
    curl -s -X POST http://127.0.0.1:8189/api/hf-auth-login-start | python3 -m json.tool
    +# expect: {"authorize_url": "https://huggingface.co/oauth/authorize?client_id=<your-uuid>&..."}
    +
  6. +
  7. Open authorize_url in a browser. The consent screen should display the + Application Name you chose and list the three scopes (openid, profile, + gated-repos). Click Authorize.
  8. +
  9. HF redirects to http://127.0.0.1:41954/api/auth/huggingface/callback?code=...&state=.... + Our local callback server completes the token exchange and renders a "Login complete" page.
  10. +
  11. + Confirm token is held: +
    curl -s http://127.0.0.1:8189/api/hf-auth-token-status | python3 -m json.tool
    +# expect: {"token_available": true, "username": "your-hf-username"}
    +
  12. +
+ +

+ Once that round-trip works, the missing-models card will use the token automatically for + every subsequent gated probe and download. +

+ +

11.7 If you need to change the callback port

+ +

The port 41954 is arbitrary — chosen to be high and unlikely to collide. +If you ever need to change it, three things must move together:

+ +
    +
  • CALLBACK_PORT in app/model_downloader/hf_auth/oauth.py.
  • +
  • The Redirect URL registered on the HuggingFace app (must match exactly, including + port).
  • +
  • The redirect-URI constant in any test fixtures (search for the port number in + tests-unit/app_test/hf_auth_test.py).
  • +
+ +

If they drift out of sync, HF will reject the redirect with a +redirect_uri_mismatch error and the callback never lands.

+ + +
+

+ Generated as a feature handover. Living document — keep it updated as the feature evolves, + or replace with a proper docs site entry once one exists. +

+ + + diff --git a/openapi.yaml b/openapi.yaml index 380e4476e..4b26daa52 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -188,6 +188,49 @@ components: - id - updated_at type: object + AvailabilityStatusRequest: + description: | + Models to query — each entry is `model_id → URL`. The URL lets + the server compute file_size + is_hf_downloadable on the same + request, eliminating the need for a separate metadata endpoint. + properties: + models: + additionalProperties: + type: string + description: model_id → URL declared in the workflow. + type: object + required: + - models + type: object + AvailabilityStatusResponse: + description: Per-model state + metadata + HF auth snapshot. + properties: + hf_auth: + $ref: '#/components/schemas/HfAuthStatus' + models: + additionalProperties: + $ref: '#/components/schemas/ModelStatusEntry' + type: object + required: + - models + - hf_auth + type: object + CancelDownloadSessionRequest: + description: Request to cancel an in-flight download for a given model_id. + properties: + model_id: + type: string + required: + - model_id + type: object + CancelDownloadSessionResponse: + description: Result of a cancellation request. + properties: + cancelled: + type: boolean + required: + - cancelled + type: object CreateWorkflowRequest: description: Request body for creating a new saved workflow. properties: @@ -230,6 +273,51 @@ components: - base_version - workflow_json type: object + DownloadModelsRequest: + description: Map of model_id → URL of files to fetch into the model folders. + properties: + models: + additionalProperties: + type: string + description: model_id → URL of models to download. + type: object + required: + - models + type: object + DownloadModelsResponse: + description: Acknowledgement that downloads have been scheduled. + properties: + accepted: + description: Always true; the request was scheduled. + type: boolean + scheduled: + description: The list of model_ids whose downloads are now in-flight. + items: + type: string + type: array + required: + - accepted + - scheduled + type: object + DownloadProgress: + description: In-flight download progress; embedded in ModelStatusEntry. + properties: + bytes_downloaded: + format: int64 + type: integer + progress: + description: Fraction in [0,1]; null until total_bytes is known. + format: float + nullable: true + type: number + total_bytes: + description: Content-Length when supplied by the source. + format: int64 + nullable: true + type: integer + required: + - bytes_downloaded + type: object ErrorResponse: description: Standard error response with a machine-readable code and human-readable message. properties: @@ -394,6 +482,46 @@ components: - name - info type: object + HfAuthLoginStartResponse: + description: URL the frontend should open in a new tab to complete login. + properties: + authorize_url: + type: string + required: + - authorize_url + type: object + HfAuthLogoutResponse: + description: Result of the logout call (always logged_out = true). + properties: + logged_out: + type: boolean + required: + - logged_out + type: object + HfAuthStatus: + description: Inline snapshot of the server's HuggingFace OAuth state. + properties: + eligible: + description: True iff this deployment can surface interactive HF login. + type: boolean + token_available: + description: True iff a token (possibly expired but refreshable) is stored. + type: boolean + required: + - token_available + - eligible + type: object + HfAuthTokenStatusResponse: + description: Whether the server holds an HF OAuth token + resolved username. + properties: + token_available: + type: boolean + username: + nullable: true + type: string + required: + - token_available + type: object HistoryDetailEntry: description: History entry with full prompt data properties: @@ -798,6 +926,40 @@ components: - name - folders type: object + ModelStatusEntry: + description: Everything the UI needs to render one row of the model. + properties: + file_size: + description: Bytes, when known. Cached server-side per URL. + format: int64 + nullable: true + type: integer + is_hf_downloadable: + description: | + HuggingFace-only signal. True if the server can fetch this + URL with its current auth state (public, or gated-with-access). + False if gated and lacking access. Null for non-HF URLs and + for HF URLs whose probe failed entirely. + nullable: true + type: boolean + progress: + allOf: + - $ref: '#/components/schemas/DownloadProgress' + description: Present when `state == downloading`. + nullable: true + state: + description: | + `available` — file is on disk. + `missing` — not on disk and no download in flight. + `downloading` — server is currently fetching the file. + enum: + - available + - missing + - downloading + type: string + required: + - state + type: object NodeInfo: description: Metadata describing a single ComfyUI node type and its inputs/outputs. properties: @@ -2338,6 +2500,60 @@ paths: summary: Get tag histogram for filtered assets tags: - file + /api/cancel-model-download-session: + post: + description: | + Cancels the download session for the given model_id. The worker + observes the cancellation between chunks, removes its partial `.tmp` + file, and exits without writing the destination path. + operationId: postCancelModelDownloadSession + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/CancelDownloadSessionRequest' + required: true + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/CancelDownloadSessionResponse' + description: Session cancelled. + "404": + description: No active download for that model_id. + summary: Cancel an in-flight server-side model download + tags: + - model + /api/download-models: + post: + description: | + Schedules downloads for every model_id in the request map. Returns + immediately after validation; progress is observed via + `/api/models-availability-status`. Fails atomically if any model + is already on disk, already downloading, gated, or has a URL that + is not on the server's allowlist (HuggingFace, Civitai, localhost). + operationId: postDownloadModels + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/DownloadModelsRequest' + required: true + responses: + "202": + content: + application/json: + schema: + $ref: '#/components/schemas/DownloadModelsResponse' + description: Downloads accepted and scheduled. + "400": + description: One of the requested models is invalid, gated, or has a non-allowed URL. + "409": + description: One of the requested models is already on disk or downloading. + summary: Start a server-side download of one or more models + tags: + - model /api/embeddings: get: description: Returns the list of text-encoder embeddings available on disk. @@ -2639,6 +2855,66 @@ paths: summary: Get a specific subgraph blueprint tags: - workflow + /api/hf-auth-login-start: + post: + description: | + Spawns a short-lived loopback callback server (port 41954) and + returns the URL the frontend should open in a new tab. After the + user grants consent, HF redirects back to the callback URL with + an authorization code; the server exchanges that for a token and + persists it. Subsequent `/api/hf-auth-token-status` calls will + return `token_available: true`. Rejected with 403 if the + deployment is not eligible (not loopback or in --multi-user mode); + 409 if another login attempt is already in progress. + operationId: postHfAuthLoginStart + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/HfAuthLoginStartResponse' + description: Login flow started; `authorize_url` is ready. + "403": + description: Deployment is not eligible for interactive HF login. + "409": + description: Another login attempt is already in progress. + summary: Begin a HuggingFace OAuth login flow + tags: + - model + /api/hf-auth-logout: + post: + description: Clears the in-memory cache and removes the on-disk token file. + operationId: postHfAuthLogout + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/HfAuthLogoutResponse' + description: Logged out (idempotent — succeeds even if no token was held). + summary: Drop the stored HuggingFace OAuth token + tags: + - model + /api/hf-auth-token-status: + get: + description: | + Returns `token_available: true` when the server has a token + in memory (or on disk) for HuggingFace, irrespective of whether + the access_token is currently fresh — an expired one with a + refresh_token still counts as "logged in" because we'll refresh + transparently on next use. If a username is resolvable via + `HfApi.whoami` we return that too, for the settings UI. + operationId: getHfAuthTokenStatus + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/HfAuthTokenStatusResponse' + description: Token status. + summary: Whether the server holds a usable HuggingFace OAuth token + tags: + - model /api/history: post: deprecated: true @@ -3141,6 +3417,44 @@ paths: summary: Cancel multiple jobs tags: - workflow + /api/models-availability-status: + post: + description: | + Given a map of `{model_id: url}` (model_id is + `/`), returns per-id state plus the + metadata the UI needs to render the row: + + - `state` — one of `available` / `missing` / `downloading` + - `progress` — embedded when `state == downloading` + - `file_size` — bytes (when known) + - `is_hf_downloadable` — for HF URLs only: true if the + server can currently fetch the file with its stored auth + state, false if gated and lacking access, null otherwise + + Designed for 1 Hz polling. `file_size` and the intrinsic + "is this model gated" check are cached server-side per URL; + `is_hf_downloadable` is recomputed per call so license + acceptance and login/logout transitions show up within one + poll interval without any client-side cache plumbing. + operationId: postModelsAvailabilityStatus + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/AvailabilityStatusRequest' + required: true + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/AvailabilityStatusResponse' + description: Per-model status and metadata. + "400": + description: Malformed request body. + summary: Unified per-model status + metadata for the polling UI + tags: + - model /api/node_replacements: get: description: | @@ -5087,3 +5401,5 @@ tags: name: queue - description: Job lifecycle queries name: job + - description: Server-side model availability and downloads + name: model diff --git a/requirements.txt b/requirements.txt index 0c8b1888e..830213707 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ numpy>=1.25.0 einops transformers>=4.50.3 tokenizers>=0.13.3 +huggingface_hub sentencepiece safetensors>=0.4.2 aiohttp>=3.11.8 diff --git a/server.py b/server.py index 361850f38..1a6850b87 100644 --- a/server.py +++ b/server.py @@ -47,6 +47,7 @@ from app.assets.seeder import asset_seeder from app.assets.api.routes import register_assets_routes from app.assets.services.ingest import register_file_in_place from app.assets.services.asset_management import resolve_hash_to_path +from app.model_downloader.api.routes import register_routes as register_model_downloader_routes from app.user_manager import UserManager from app.model_manager import ModelFileManager @@ -256,6 +257,7 @@ class PromptServer(): else: register_assets_routes(self.app) asset_seeder.disable() + register_model_downloader_routes(self.app) routes = web.RouteTableDef() self.routes = routes self.last_node_id = None diff --git a/tests-unit/app_test/hf_auth_test.py b/tests-unit/app_test/hf_auth_test.py new file mode 100644 index 000000000..39d0d6ee7 --- /dev/null +++ b/tests-unit/app_test/hf_auth_test.py @@ -0,0 +1,567 @@ +"""Unit tests for the HuggingFace auth subsystem. + +Covers: + - token store: save/load roundtrip, chmod 0600, atomic write, delete + - eligibility under various CLI-arg combinations + - URL parsing (huggingface.co host detection + repo_id extraction) + - HF-aware gated_detection.probe_url (mocked auth_check) + - HF auth routes (token status, login start with eligibility gate, logout) + - PKCE primitives + authorize URL shape + +The OAuth callback server itself isn't exercised end-to-end here — that +requires a real HF server. We test the components (state checking, +URL building, code-exchange request shape) instead. +""" + +from __future__ import annotations + +import json +import os +import stat +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from aiohttp import web + +from app.model_downloader.api.routes import register_routes +from app.model_downloader.hf_auth import oauth +from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE, HfAuthStore +from app.model_downloader.hf_auth.token_store import ( + EXPIRY_BUFFER_SECS, + Token, + delete_token, + load_token, + save_token, +) +from app.model_downloader.hf_url import is_hf_url, repo_id_from_url + +pytestmark = pytest.mark.asyncio + + +# --------------------------------------------------------------------------- # +# Fixtures +# --------------------------------------------------------------------------- # + + +@pytest.fixture +def patched_user_dir(tmp_path): + """Redirect ``folder_paths.get_user_directory`` so the token file + lands in an isolated tmp_path instead of the real user dir.""" + user_dir = tmp_path / "user" + user_dir.mkdir() + with patch("folder_paths.get_user_directory", return_value=str(user_dir)): + yield user_dir + + +@pytest.fixture +def fresh_auth_store(): + """Wipe singleton state between tests: auth + probe caches.""" + from app.model_downloader import gated_detection + + HF_AUTH_STORE._token = None + HF_AUTH_STORE._loaded_from_disk = False + gated_detection.clear_caches_for_tests() + yield HF_AUTH_STORE + HF_AUTH_STORE._token = None + HF_AUTH_STORE._loaded_from_disk = False + gated_detection.clear_caches_for_tests() + + +@pytest.fixture +def app(patched_user_dir, fresh_auth_store): + app = web.Application() + register_routes(app) + return app + + +# --------------------------------------------------------------------------- # +# URL parsing +# --------------------------------------------------------------------------- # + + +def test_is_hf_url_recognises_huggingface_co(): + assert is_hf_url("https://huggingface.co/x/y/resolve/main/z.safetensors") + assert is_hf_url("https://huggingface.co/abc") + assert not is_hf_url("https://hf-mirror.com/x/y/resolve/main/z.safetensors") + assert not is_hf_url("https://civitai.com/x.safetensors") + + +def test_repo_id_from_url_extracts_org_and_repo(): + url = "https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-HDR/resolve/main/x.safetensors" + assert repo_id_from_url(url) == "Lightricks/LTX-2.3-22b-IC-LoRA-HDR" + + +def test_repo_id_from_url_handles_nested_path(): + url = "https://huggingface.co/Comfy-Org/ltx-2.3/resolve/main/split_files/loras/x.safetensors" + assert repo_id_from_url(url) == "Comfy-Org/ltx-2.3" + + +def test_repo_id_from_url_returns_none_for_non_hf(): + assert repo_id_from_url("https://civitai.com/x.safetensors") is None + + +def test_repo_id_from_url_returns_none_for_non_resolve_paths(): + assert repo_id_from_url("https://huggingface.co/org/repo/blob/main/x.safetensors") is None + assert repo_id_from_url("https://huggingface.co/org") is None + + +# --------------------------------------------------------------------------- # +# Token store +# --------------------------------------------------------------------------- # + + +def test_token_store_roundtrip(patched_user_dir): + tok = Token( + access_token="hf_abc", + refresh_token="rf_def", + expires_at=9999999999.0, + scope="openid profile", + ) + save_token(tok) + loaded = load_token() + assert loaded == tok + + +def test_token_store_writes_0600(patched_user_dir): + tok = Token(access_token="x", refresh_token=None, expires_at=0.0) + save_token(tok) + path = os.path.join(patched_user_dir, "hf_auth_token.json") + mode = stat.S_IMODE(os.stat(path).st_mode) + # On Windows we silently no-op chmod; allow either the intended + # mode or whatever umask the OS gave us. + if os.name == "posix": + assert mode == 0o600 + + +def test_token_store_delete_removes_file(patched_user_dir): + tok = Token(access_token="x", refresh_token=None, expires_at=0.0) + save_token(tok) + delete_token() + path = os.path.join(patched_user_dir, "hf_auth_token.json") + assert not os.path.exists(path) + # Idempotent: second delete is fine. + delete_token() + + +def test_token_store_load_returns_none_for_missing_file(patched_user_dir): + assert load_token() is None + + +def test_token_store_load_returns_none_for_corrupt_file(patched_user_dir): + path = os.path.join(patched_user_dir, "hf_auth_token.json") + with open(path, "w") as f: + f.write("not json {") + assert load_token() is None + + +def test_token_is_valid_uses_buffer(patched_user_dir): + import time + + fresh = Token(access_token="x", refresh_token=None, expires_at=time.time() + 3600) + nearly_expired = Token( + access_token="x", + refresh_token=None, + expires_at=time.time() + EXPIRY_BUFFER_SECS - 1, + ) + assert fresh.is_valid() + assert not nearly_expired.is_valid() + + +# --------------------------------------------------------------------------- # +# Auth store +# --------------------------------------------------------------------------- # + + +def test_auth_store_loads_lazily(patched_user_dir): + tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0) + save_token(tok) + store = HfAuthStore() + assert store.has_token() + assert store.get_token_sync() == tok + + +def test_auth_store_set_persists(patched_user_dir): + store = HfAuthStore() + tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0) + store.set_token(tok) + # Token is on disk now — a fresh store sees it. + assert HfAuthStore().get_token_sync() == tok + + +def test_auth_store_clear_removes_in_memory_and_on_disk(patched_user_dir): + store = HfAuthStore() + tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0) + store.set_token(tok) + store.clear() + assert not store.has_token() + assert HfAuthStore().get_token_sync() is None + + +async def test_auth_store_get_valid_returns_fresh_token(patched_user_dir): + store = HfAuthStore() + import time + + tok = Token(access_token="x", refresh_token=None, expires_at=time.time() + 3600) + store.set_token(tok) + fetched = await store.get_valid_token() + assert fetched == tok + + +async def test_auth_store_get_valid_refresh_on_expired(patched_user_dir): + store = HfAuthStore() + import time + + expired = Token( + access_token="old", + refresh_token="rf", + expires_at=time.time() - 100, + ) + store.set_token(expired) + refreshed = Token( + access_token="new", + refresh_token="rf", + expires_at=time.time() + 3600, + ) + with patch( + "app.model_downloader.hf_auth.oauth.refresh_access_token", + new=AsyncMock(return_value=refreshed), + ): + result = await store.get_valid_token() + assert result == refreshed + + +async def test_auth_store_get_valid_returns_none_on_refresh_failure(patched_user_dir): + store = HfAuthStore() + import time + + expired = Token( + access_token="old", + refresh_token="rf", + expires_at=time.time() - 100, + ) + store.set_token(expired) + with patch( + "app.model_downloader.hf_auth.oauth.refresh_access_token", + new=AsyncMock(side_effect=RuntimeError("HF down")), + ): + result = await store.get_valid_token() + assert result is None + + +# --------------------------------------------------------------------------- # +# Eligibility +# --------------------------------------------------------------------------- # + + +@pytest.mark.parametrize( + "listen,multi_user,expected", + [ + ("127.0.0.1", False, True), + ("127.0.0.1", True, False), # multi-user disables it + ("0.0.0.0", False, False), # bind-all is not loopback + ("0.0.0.0", True, False), + ("192.168.1.5", False, False), # LAN address + ("::1", False, True), # IPv6 loopback + ], +) +def test_eligibility(listen, multi_user, expected, monkeypatch): + from app.model_downloader.hf_auth import eligibility + from comfy.cli_args import args + + monkeypatch.setattr(args, "listen", listen) + monkeypatch.setattr(args, "multi_user", multi_user) + assert eligibility.is_hf_auth_eligible() is expected + + +# --------------------------------------------------------------------------- # +# gated_detection HF probe +# --------------------------------------------------------------------------- # + + +async def test_probe_url_hf_public(fresh_auth_store): + """auth_check succeeds with no token → is_hf_downloadable = True.""" + from app.model_downloader.gated_detection import probe_url + + url = "https://huggingface.co/public/repo/resolve/main/x.safetensors" + with patch("app.model_downloader.gated_detection._auth_check_sync"), patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=1024), + ): + result = await probe_url(url) + assert result.is_hf_downloadable is True + assert result.file_size == 1024 + + +async def test_probe_url_hf_gated_no_access(fresh_auth_store): + """auth_check raises GatedRepoError → is_hf_downloadable = False.""" + from huggingface_hub.errors import GatedRepoError + + from app.model_downloader.gated_detection import probe_url + + url = "https://huggingface.co/gated/repo/resolve/main/x.safetensors" + fake_response = MagicMock(status_code=403) + with patch( + "app.model_downloader.gated_detection._auth_check_sync", + side_effect=GatedRepoError("gated", response=fake_response), + ), patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=None), + ): + result = await probe_url(url) + assert result.is_hf_downloadable is False + + +async def test_probe_url_non_hf_skips_auth_check(): + """Non-HF URLs never call auth_check; is_hf_downloadable stays None.""" + from app.model_downloader.gated_detection import probe_url + + url = "https://civitai.com/api/download/models/1.safetensors" + with patch( + "app.model_downloader.gated_detection._auth_check_sync", + ) as mocked, patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=2048), + ): + result = await probe_url(url) + assert result.is_hf_downloadable is None + assert result.file_size == 2048 + mocked.assert_not_called() + + +async def test_is_gated_cached_across_calls(fresh_auth_store): + """Intrinsic ``is_gated`` should be determined exactly once per URL. + + Subsequent ``probe_url`` calls for the same URL must not re-issue + the null-token auth_check — that's the whole point of the cache.""" + from app.model_downloader.gated_detection import probe_url + + url = "https://huggingface.co/public/repo/resolve/main/x.safetensors" + with patch( + "app.model_downloader.gated_detection._auth_check_sync" + ) as mocked, patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=1024), + ): + await probe_url(url) + await probe_url(url) + await probe_url(url) + # Three probe_url calls × public-only-needs-1-auth_check = 1 call total. + assert mocked.call_count == 1 + + +async def test_file_size_cached_across_calls(fresh_auth_store): + """Once a successful HEAD lands, subsequent calls don't re-HEAD.""" + from app.model_downloader.gated_detection import probe_url + + url = "https://huggingface.co/public/repo/resolve/main/x.safetensors" + with patch( + "app.model_downloader.gated_detection._auth_check_sync" + ), patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=2048), + ) as size_probe: + r1 = await probe_url(url) + r2 = await probe_url(url) + assert r1.file_size == 2048 + assert r2.file_size == 2048 + assert size_probe.call_count == 1 + + +async def test_file_size_not_probed_for_gated_no_access(fresh_auth_store): + """When ``is_hf_downloadable`` is False we must NOT HEAD the URL — + otherwise a 401-due-to-gating would land as a cached ``None`` that + survives a later successful login.""" + from app.model_downloader.gated_detection import probe_url + from huggingface_hub.errors import GatedRepoError + from unittest.mock import MagicMock + + url = "https://huggingface.co/gated/repo/resolve/main/x.safetensors" + fake_resp = MagicMock(status_code=403) + with patch( + "app.model_downloader.gated_detection._auth_check_sync", + side_effect=GatedRepoError("gated", response=fake_resp), + ), patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=None), + ) as size_probe: + result = await probe_url(url) + assert result.is_hf_downloadable is False + assert result.file_size is None + assert size_probe.call_count == 0 + + +async def test_probe_url_passes_token_when_available(fresh_auth_store, patched_user_dir): + """For a gated URL, auth_check runs twice: once with token=None to + determine the intrinsic ``is_gated`` flag (cached forever), and once + with the stored access_token to determine ``is_hf_downloadable`` for + the current user.""" + from app.model_downloader import gated_detection + from app.model_downloader.gated_detection import probe_url + from huggingface_hub.errors import GatedRepoError + from unittest.mock import MagicMock + + gated_detection.clear_caches_for_tests() + fresh_auth_store.set_token(Token( + access_token="hf_test_token", + refresh_token=None, + expires_at=9999999999.0, + )) + url = "https://huggingface.co/private/repo/resolve/main/x.safetensors" + + fake_resp = MagicMock(status_code=403) + + def fake_auth_check(repo_id, token): + # Null-token call → repo is gated. Subsequent call with the real + # token succeeds (user has access). + if token is None: + raise GatedRepoError("gated", response=fake_resp) + + with patch( + "app.model_downloader.gated_detection._auth_check_sync", + side_effect=fake_auth_check, + ) as mocked, patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=None), + ): + result = await probe_url(url) + + # is_hf_downloadable should be True (token-authed call succeeded). + assert result.is_hf_downloadable is True + # Two calls: (repo_id, None) then (repo_id, ). + assert mocked.call_count == 2 + assert mocked.call_args_list[0].args == ("private/repo", None) + assert mocked.call_args_list[1].args == ("private/repo", "hf_test_token") + + +# --------------------------------------------------------------------------- # +# OAuth primitives +# --------------------------------------------------------------------------- # + + +def test_make_pkce_returns_distinct_high_entropy_values(): + verifier1, challenge1, state1 = oauth._make_pkce() + verifier2, challenge2, state2 = oauth._make_pkce() + assert verifier1 != verifier2 + assert challenge1 != challenge2 + assert state1 != state2 + # Verifier should be at least 43 chars per PKCE spec. + assert len(verifier1) >= 43 + + +def test_build_authorize_url_includes_pkce_and_state(): + url = oauth._build_authorize_url("challenge123", "state456") + assert url.startswith(oauth.AUTHORIZE_URL) + assert "client_id=" + oauth.HF_CLIENT_ID in url + assert "code_challenge=challenge123" in url + assert "code_challenge_method=S256" in url + assert "state=state456" in url + assert "response_type=code" in url + + +# --------------------------------------------------------------------------- # +# Routes +# --------------------------------------------------------------------------- # + + +async def test_hf_auth_token_status_empty(aiohttp_client, app): + """No token set → token_available=false, username=null.""" + client = await aiohttp_client(app) + resp = await client.get("/api/hf-auth-token-status") + assert resp.status == 200 + data = await resp.json() + assert data == {"token_available": False, "username": None} + + +async def test_hf_auth_token_status_with_token( + aiohttp_client, app, fresh_auth_store, patched_user_dir +): + """Token present, whoami works → username is returned.""" + fresh_auth_store.set_token(Token( + access_token="x", refresh_token=None, expires_at=9999999999.0, + )) + with patch( + "app.model_downloader.api.routes._whoami_username", + return_value="alice", + ): + client = await aiohttp_client(app) + resp = await client.get("/api/hf-auth-token-status") + assert resp.status == 200 + assert (await resp.json()) == {"token_available": True, "username": "alice"} + + +async def test_hf_auth_login_start_403_when_ineligible(aiohttp_client, app, monkeypatch): + """Not loopback / multi-user → 403.""" + monkeypatch.setattr( + "app.model_downloader.api.routes.is_hf_auth_eligible", + lambda: False, + ) + client = await aiohttp_client(app) + resp = await client.post("/api/hf-auth-login-start") + assert resp.status == 403 + assert (await resp.json())["error"]["code"] == "HF_AUTH_NOT_ELIGIBLE" + + +async def test_hf_auth_login_start_returns_authorize_url(aiohttp_client, app, monkeypatch): + """Eligible + first attempt → 200 with authorize_url.""" + monkeypatch.setattr( + "app.model_downloader.api.routes.is_hf_auth_eligible", + lambda: True, + ) + monkeypatch.setattr( + "app.model_downloader.api.routes.start_login_flow", + AsyncMock(return_value="https://huggingface.co/oauth/authorize?fake=1"), + ) + client = await aiohttp_client(app) + resp = await client.post("/api/hf-auth-login-start") + assert resp.status == 200 + assert (await resp.json())["authorize_url"].startswith( + "https://huggingface.co/oauth/authorize" + ) + + +async def test_hf_auth_login_start_409_when_in_progress(aiohttp_client, app, monkeypatch): + """Lock already held → 409.""" + from app.model_downloader.hf_auth.oauth import OAuthInProgressError + + monkeypatch.setattr( + "app.model_downloader.api.routes.is_hf_auth_eligible", + lambda: True, + ) + monkeypatch.setattr( + "app.model_downloader.api.routes.start_login_flow", + AsyncMock(side_effect=OAuthInProgressError()), + ) + client = await aiohttp_client(app) + resp = await client.post("/api/hf-auth-login-start") + assert resp.status == 409 + assert (await resp.json())["error"]["code"] == "HF_AUTH_IN_PROGRESS" + + +async def test_hf_auth_logout_clears_store( + aiohttp_client, app, fresh_auth_store, patched_user_dir +): + fresh_auth_store.set_token(Token( + access_token="x", refresh_token=None, expires_at=9999999999.0, + )) + client = await aiohttp_client(app) + resp = await client.post("/api/hf-auth-logout") + assert resp.status == 200 + assert (await resp.json()) == {"logged_out": True} + assert not fresh_auth_store.has_token() + + +async def test_availability_includes_hf_auth_snapshot(aiohttp_client, app, monkeypatch): + """The availability response embeds {token_available, eligible}.""" + monkeypatch.setattr( + "app.model_downloader.api.routes.is_hf_auth_eligible", + lambda: True, + ) + client = await aiohttp_client(app) + resp = await client.post( + "/api/models-availability-status", + json={"models": {}}, + ) + assert resp.status == 200 + data = await resp.json() + assert "hf_auth" in data + assert data["hf_auth"] == {"token_available": False, "eligible": True} diff --git a/tests-unit/app_test/model_downloader_test.py b/tests-unit/app_test/model_downloader_test.py new file mode 100644 index 000000000..581700793 --- /dev/null +++ b/tests-unit/app_test/model_downloader_test.py @@ -0,0 +1,504 @@ +"""Unit tests for the server-side model download subsystem. + +Covers the pieces that don't require talking to a real network: + + - path parsing & allowlist (pure functions) + - DownloadServer registry lifecycle (in-memory state) + - API routes via aiohttp_client + folder_paths/probe_url patches + +Streaming downloads themselves are exercised indirectly — the route-level +tests stub out the network probe so we can verify the gating logic in +``download_models`` without making real HTTP calls. +""" + +from __future__ import annotations + +import asyncio +import os +from unittest.mock import patch, AsyncMock + +import pytest +from aiohttp import web + +from app.model_downloader.allowlist import is_url_allowed +from app.model_downloader.api.routes import register_routes +from app.model_downloader.download_server import DownloadServer +from app.model_downloader.gated_detection import MetadataProbeResult +from app.model_downloader.paths import ( + InvalidModelId, + parse_model_id, + resolve_destination, + resolve_existing, +) + +# Global asyncio mark: the sync tests below trigger a cosmetic +# PytestWarning for each one because pytest-asyncio applies the mark +# indiscriminately. Other tests in this repo (see custom_node_manager_test.py) +# use the same pattern. The warnings are noise, not failures. +pytestmark = pytest.mark.asyncio + + +# --------------------------------------------------------------------------- # +# Fixtures +# --------------------------------------------------------------------------- # + + +@pytest.fixture +def model_root(tmp_path): + """A fake ``models/`` root with two registered folder types.""" + loras_dir = tmp_path / "loras" + checkpoints_dir = tmp_path / "checkpoints" + loras_dir.mkdir() + checkpoints_dir.mkdir() + return tmp_path, loras_dir, checkpoints_dir + + +@pytest.fixture +def patched_folder_paths(model_root): + """Point folder_paths at our fake roots for the duration of one test.""" + _root, loras_dir, checkpoints_dir = model_root + mapping = { + "loras": ([str(loras_dir)], {".safetensors"}), + "checkpoints": ([str(checkpoints_dir)], {".safetensors"}), + } + with patch( + "folder_paths.folder_names_and_paths", mapping + ), patch( + "folder_paths.get_folder_paths", + side_effect=lambda name: mapping.get(name, ([], set()))[0], + ): + yield mapping + + +@pytest.fixture +def fresh_download_server(): + """Reset the module-level singleton between tests so registry state + doesn't leak across tests sharing the singleton.""" + from app.model_downloader.download_server import DOWNLOAD_SERVER + + DOWNLOAD_SERVER.reset_for_tests() + yield DOWNLOAD_SERVER + DOWNLOAD_SERVER.reset_for_tests() + + +@pytest.fixture +def app(patched_folder_paths, fresh_download_server): + app = web.Application() + register_routes(app) + return app + + +# --------------------------------------------------------------------------- # +# Pure helpers: allowlist + path parsing +# --------------------------------------------------------------------------- # + + +def test_allowlist_accepts_hf_safetensors(): + assert is_url_allowed("https://huggingface.co/x/y/resolve/main/z.safetensors") + + +def test_allowlist_accepts_civitai_pth(): + assert is_url_allowed("https://civitai.com/api/download/models/123.pth") + + +def test_allowlist_rejects_unknown_host(): + assert not is_url_allowed("https://example.com/x.safetensors") + + +def test_allowlist_rejects_api_path_on_hf(): + # On an allowlisted host but not pointing at a model file. + assert not is_url_allowed("https://huggingface.co/api/models") + + +def test_allowlist_rejects_non_https_except_localhost(): + assert not is_url_allowed("http://huggingface.co/x/y.safetensors") + assert is_url_allowed("http://localhost:8000/x.safetensors") + + +def test_parse_model_id_valid(patched_folder_paths): + assert parse_model_id("loras/foo.safetensors") == ("loras", "foo.safetensors") + + +def test_parse_model_id_rejects_traversal(patched_folder_paths): + with pytest.raises(InvalidModelId): + parse_model_id("../etc/passwd") + + +def test_parse_model_id_rejects_unknown_folder(patched_folder_paths): + with pytest.raises(InvalidModelId): + parse_model_id("nope/x.safetensors") + + +def test_parse_model_id_rejects_double_slash(patched_folder_paths): + with pytest.raises(InvalidModelId): + parse_model_id("loras/sub/x.safetensors") + + +def test_resolve_existing_returns_path_when_present(model_root, patched_folder_paths): + _root, loras_dir, _ = model_root + target = loras_dir / "foo.safetensors" + target.write_bytes(b"x") + assert resolve_existing("loras/foo.safetensors") == str(target) + + +def test_resolve_existing_returns_none_when_absent(patched_folder_paths): + assert resolve_existing("loras/missing.safetensors") is None + + +def test_resolve_destination_returns_tmp_pair(model_root, patched_folder_paths): + _root, loras_dir, _ = model_root + final, tmp = resolve_destination("loras/foo.safetensors") + assert final == str(loras_dir / "foo.safetensors") + assert tmp == final + ".tmp" + + +# --------------------------------------------------------------------------- # +# DownloadServer registry: lifecycle, races, cancellation epoch semantics +# --------------------------------------------------------------------------- # + + +def test_register_is_exclusive(): + server = DownloadServer() + s1 = server.try_register("loras/x.safetensors", "https://huggingface.co/a") + s2 = server.try_register("loras/x.safetensors", "https://huggingface.co/b") + assert s1 is not None + assert s2 is None + assert server.is_downloading("loras/x.safetensors") + + +def test_cancel_removes_session(): + server = DownloadServer() + server.try_register("loras/x.safetensors", "https://huggingface.co/a") + assert server.cancel("loras/x.safetensors") is True + assert not server.is_downloading("loras/x.safetensors") + + +def test_cancel_returns_false_when_absent(): + server = DownloadServer() + assert server.cancel("loras/never.safetensors") is False + + +def test_finish_only_clears_matching_epoch(): + """If a session is cancelled and a new one for the same id is + registered, ``finish`` from the original worker must not evict the + newer session.""" + server = DownloadServer() + s_old = server.try_register("loras/x.safetensors", "u1") + server.cancel("loras/x.safetensors") + s_new = server.try_register("loras/x.safetensors", "u2") + assert s_new is not None and s_new.epoch != s_old.epoch + # Old worker's late finish() is a no-op: + server.finish(s_old) + assert server.is_downloading("loras/x.safetensors") + server.finish(s_new) + assert not server.is_downloading("loras/x.safetensors") + + +def test_is_active_follows_cancellation(): + server = DownloadServer() + s = server.try_register("loras/x.safetensors", "u") + assert server.is_active(s) + server.cancel("loras/x.safetensors") + assert not server.is_active(s) + + +def test_update_progress_tracks_fraction(): + server = DownloadServer() + s = server.try_register("loras/x.safetensors", "u") + server.update_progress(s, 50, 100) + snap = server.snapshot()["loras/x.safetensors"] + assert snap.bytes_downloaded == 50 + assert snap.total_bytes == 100 + assert snap.progress == 0.5 + + +def test_update_progress_with_unknown_total_keeps_progress_none(): + server = DownloadServer() + s = server.try_register("loras/x.safetensors", "u") + server.update_progress(s, 50, None) + assert server.snapshot()["loras/x.safetensors"].progress is None + + +def test_cleanup_orphan_tmp_files(model_root): + """Orphan .tmp left by a crashed download must be swept on first use.""" + _root, loras_dir, _ = model_root + orphan = loras_dir / "stale.safetensors.tmp" + orphan.write_bytes(b"partial") + mapping = {"loras": ([str(loras_dir)], {".safetensors"})} + with patch("folder_paths.folder_names_and_paths", mapping), patch( + "folder_paths.get_folder_paths", + side_effect=lambda name: mapping.get(name, ([], set()))[0], + ): + server = DownloadServer() + assert orphan.exists(), "sweep must not run at construction time" + server.sweep_orphan_tmp_files() + assert not orphan.exists() + # Idempotent — a second call is a cheap no-op. + server.sweep_orphan_tmp_files() + + +# --------------------------------------------------------------------------- # +# Route: POST /api/models-availability-status +# --------------------------------------------------------------------------- # + + +async def test_availability_partitions_correctly( + aiohttp_client, app, model_root, fresh_download_server +): + _root, loras_dir, _ = model_root + (loras_dir / "present.safetensors").write_bytes(b"x") + fresh_download_server.try_register( + "loras/inflight.safetensors", "http://localhost:8000/x.safetensors" + ) + client = await aiohttp_client(app) + + # Stub probes — we're testing state assignment, not network calls. + with patch( + "app.model_downloader.api.routes.probe_url", + new=AsyncMock(return_value=MetadataProbeResult( + file_size=None, is_hf_downloadable=None, + )), + ): + body = { + "models": { + "loras/present.safetensors": "http://localhost:8000/p.safetensors", + "loras/missing.safetensors": "http://localhost:8000/m.safetensors", + "loras/inflight.safetensors": "http://localhost:8000/x.safetensors", + } + } + resp = await client.post("/api/models-availability-status", json=body) + assert resp.status == 200 + data = await resp.json() + models = data["models"] + assert models["loras/present.safetensors"]["state"] == "available" + assert models["loras/missing.safetensors"]["state"] == "missing" + assert models["loras/inflight.safetensors"]["state"] == "downloading" + assert "hf_auth" in data + + +async def test_availability_invalid_id_classified_as_missing(aiohttp_client, app): + client = await aiohttp_client(app) + with patch( + "app.model_downloader.api.routes.probe_url", + new=AsyncMock(return_value=MetadataProbeResult( + file_size=None, is_hf_downloadable=None, + )), + ): + resp = await client.post( + "/api/models-availability-status", + json={"models": {"../etc/passwd": "http://localhost:8000/x.safetensors"}}, + ) + assert resp.status == 200 + data = await resp.json() + assert data["models"]["../etc/passwd"]["state"] == "missing" + + +# --------------------------------------------------------------------------- # +# Route: POST /api/download-models — precondition gating +# --------------------------------------------------------------------------- # + + +async def test_download_rejects_url_not_in_allowlist(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={"models": {"loras/x.safetensors": "https://evil.com/x.safetensors"}}, + ) + assert resp.status == 400 + err = (await resp.json())["error"] + assert err["code"] == "URL_NOT_ALLOWED" + + +async def test_download_rejects_already_available( + aiohttp_client, app, model_root +): + _root, loras_dir, _ = model_root + (loras_dir / "x.safetensors").write_bytes(b"x") + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={"models": { + "loras/x.safetensors": "https://huggingface.co/a/b/resolve/main/x.safetensors" + }}, + ) + assert resp.status == 409 + assert (await resp.json())["error"]["code"] == "ALREADY_AVAILABLE" + + +async def test_download_rejects_already_downloading( + aiohttp_client, app, fresh_download_server +): + fresh_download_server.try_register( + "loras/x.safetensors", "https://huggingface.co/u.safetensors" + ) + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={"models": { + "loras/x.safetensors": "https://huggingface.co/a/b/resolve/main/x.safetensors" + }}, + ) + assert resp.status == 409 + assert (await resp.json())["error"]["code"] == "ALREADY_DOWNLOADING" + + +async def test_download_rejects_gated_model(aiohttp_client, app): + client = await aiohttp_client(app) + with patch( + "app.model_downloader.api.routes.probe_url", + new=AsyncMock(return_value=MetadataProbeResult(file_size=None, is_hf_downloadable=False)), + ): + resp = await client.post( + "/api/download-models", + json={"models": { + "loras/x.safetensors": "https://huggingface.co/g/r/resolve/main/x.safetensors" + }}, + ) + assert resp.status == 400 + assert (await resp.json())["error"]["code"] == "MODEL_NOT_DOWNLOADABLE" + + +async def test_download_rejects_invalid_model_id(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={"models": {"../etc/passwd": "https://huggingface.co/x.safetensors"}}, + ) + assert resp.status == 400 + assert (await resp.json())["error"]["code"] == "INVALID_MODEL_ID" + + +async def test_download_atomic_failure_does_not_register_partial( + aiohttp_client, app, model_root, fresh_download_server +): + """If one model in a batch fails, none get registered.""" + _root, loras_dir, _ = model_root + (loras_dir / "already.safetensors").write_bytes(b"x") + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={ + "models": { + "loras/already.safetensors": + "https://huggingface.co/a/b/resolve/main/already.safetensors", + "loras/new.safetensors": + "https://huggingface.co/a/b/resolve/main/new.safetensors", + } + }, + ) + assert resp.status == 409 + # The "new" model should not have been registered as part of the + # failed batch. + assert not fresh_download_server.is_downloading("loras/new.safetensors") + + +async def test_download_schedules_when_all_preconditions_pass( + aiohttp_client, app, fresh_download_server +): + """Verify the precondition pass, registration pass, and async + scheduling all wire up correctly. We patch the streamer to avoid + real HTTP while still letting the route execute end-to-end.""" + started = asyncio.Event() + finish_signal = asyncio.Event() + + async def fake_stream(session): + started.set() + await finish_signal.wait() + from app.model_downloader.download_server import DOWNLOAD_SERVER + DOWNLOAD_SERVER.finish(session) + return "/dev/null" + + with patch( + "app.model_downloader.api.routes.probe_url", + new=AsyncMock(return_value=MetadataProbeResult(file_size=42, is_hf_downloadable=True)), + ), patch( + "app.model_downloader.downloader.stream_to_disk", new=fake_stream + ): + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={"models": { + "loras/new.safetensors": + "https://huggingface.co/a/b/resolve/main/new.safetensors" + }}, + ) + assert resp.status == 202 + body = await resp.json() + assert body["accepted"] is True + assert body["scheduled"] == ["loras/new.safetensors"] + # Wait for the worker to actually start. + await asyncio.wait_for(started.wait(), timeout=2.0) + assert fresh_download_server.is_downloading("loras/new.safetensors") + finish_signal.set() + + +# --------------------------------------------------------------------------- # +# Route: POST /api/cancel-model-download-session +# --------------------------------------------------------------------------- # + + +async def test_cancel_removes_active_session( + aiohttp_client, app, fresh_download_server +): + fresh_download_server.try_register( + "loras/x.safetensors", "https://huggingface.co/u.safetensors" + ) + client = await aiohttp_client(app) + resp = await client.post( + "/api/cancel-model-download-session", + json={"model_id": "loras/x.safetensors"}, + ) + assert resp.status == 200 + assert (await resp.json())["cancelled"] is True + assert not fresh_download_server.is_downloading("loras/x.safetensors") + + +async def test_cancel_returns_404_when_none(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.post( + "/api/cancel-model-download-session", + json={"model_id": "loras/nothing.safetensors"}, + ) + assert resp.status == 404 + assert (await resp.json())["error"]["code"] == "NOT_DOWNLOADING" + + +# --------------------------------------------------------------------------- # +# Unified availability response embeds metadata per id +# --------------------------------------------------------------------------- # + + +async def test_availability_embeds_metadata(aiohttp_client, app): + """``file_size`` + ``is_hf_downloadable`` come back on the same + request as the state — no separate metadata endpoint.""" + results = { + "https://huggingface.co/a/b/resolve/main/free.safetensors": + MetadataProbeResult(file_size=1024, is_hf_downloadable=True), + "https://huggingface.co/g/r/resolve/main/gated.safetensors": + MetadataProbeResult(file_size=None, is_hf_downloadable=False), + } + + async def fake_probe(url): + return results[url] + + with patch( + "app.model_downloader.api.routes.probe_url", new=fake_probe + ): + client = await aiohttp_client(app) + resp = await client.post( + "/api/models-availability-status", + json={ + "models": { + "loras/free.safetensors": + "https://huggingface.co/a/b/resolve/main/free.safetensors", + "loras/gated.safetensors": + "https://huggingface.co/g/r/resolve/main/gated.safetensors", + } + }, + ) + assert resp.status == 200 + models = (await resp.json())["models"] + assert models["loras/free.safetensors"]["file_size"] == 1024 + assert models["loras/free.safetensors"]["is_hf_downloadable"] is True + assert models["loras/gated.safetensors"]["file_size"] is None + assert models["loras/gated.safetensors"]["is_hf_downloadable"] is False