From 351119eb050a74724795c18a1db5f2d0eafa99a3 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 24 Jun 2026 09:06:22 +0300 Subject: [PATCH] feat(model_downloader): add server-side model downloads with gated-repo support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Lets ComfyUI fetch the models a workflow needs directly on the server, so users no longer have to locate each file and drop it into the correct folder by hand. Crucially it supports gated HuggingFace repositories: the user logs in once via HuggingFace, after which the server can download models that require license acceptance or authentication — previously a manual, error-prone step. The frontend can surface per-model availability and download progress through the accompanying API. --- app/model_downloader/allowlist.py | 51 ++ app/model_downloader/api/routes.py | 359 ++++++++++ app/model_downloader/api/schemas_in.py | 41 ++ app/model_downloader/api/schemas_out.py | 81 +++ app/model_downloader/download_server.py | 179 +++++ app/model_downloader/downloader.py | 216 ++++++ app/model_downloader/gated_detection.py | 245 +++++++ app/model_downloader/hf_auth/auth_store.py | 121 ++++ app/model_downloader/hf_auth/eligibility.py | 55 ++ app/model_downloader/hf_auth/oauth.py | 301 ++++++++ app/model_downloader/hf_auth/token_store.py | 94 +++ app/model_downloader/hf_url.py | 41 ++ app/model_downloader/http_client.py | 63 ++ app/model_downloader/paths.py | 111 +++ tests-unit/app_test/hf_auth_test.py | 708 +++++++++++++++++++ tests-unit/app_test/model_downloader_test.py | 514 ++++++++++++++ 16 files changed, 3180 insertions(+) create mode 100644 app/model_downloader/allowlist.py create mode 100644 app/model_downloader/api/routes.py create mode 100644 app/model_downloader/api/schemas_in.py create mode 100644 app/model_downloader/api/schemas_out.py create mode 100644 app/model_downloader/download_server.py create mode 100644 app/model_downloader/downloader.py create mode 100644 app/model_downloader/gated_detection.py create mode 100644 app/model_downloader/hf_auth/auth_store.py create mode 100644 app/model_downloader/hf_auth/eligibility.py create mode 100644 app/model_downloader/hf_auth/oauth.py create mode 100644 app/model_downloader/hf_auth/token_store.py create mode 100644 app/model_downloader/hf_url.py create mode 100644 app/model_downloader/http_client.py create mode 100644 app/model_downloader/paths.py create mode 100644 tests-unit/app_test/hf_auth_test.py create mode 100644 tests-unit/app_test/model_downloader_test.py diff --git a/app/model_downloader/allowlist.py b/app/model_downloader/allowlist.py new file mode 100644 index 000000000..e548bd994 --- /dev/null +++ b/app/model_downloader/allowlist.py @@ -0,0 +1,51 @@ +"""URL allowlist for server-side model fetches. + +Mirrors the frontend's ``isModelDownloadable`` allowlist so the two flows +agree on which URLs are eligible for download. Server-side allowlisting is +the primary SSRF defense for this subsystem — workflow JSON is untrusted +input (anyone can hand-craft one), so we never let the server fetch URLs +outside this list. +""" + +from urllib.parse import urlparse + +# Frontend parity: ``missingModelDownload-*.js`` exports the same triple +# (Civitai / HuggingFace / localhost). Keyed by exact hostname → allowed +# schemes, and matched against the *parsed* host (not a raw string prefix), +# so URL-userinfo tricks can't slip past — see ``is_url_allowed``. +_ALLOWED_HOSTS = { + "huggingface.co": {"https"}, + "civitai.com": {"https"}, + "localhost": {"http"}, + "127.0.0.1": {"http"}, +} + +# Frontend parity: same set as ``a = [...]`` in the bundle. +_ALLOWED_MODEL_EXTENSIONS = ( + ".safetensors", + ".sft", + ".ckpt", + ".pth", + ".pt", +) + + +def is_url_allowed(url: str) -> bool: + """Check whether ``url`` is permitted as a server-side download source. + + True only when the parsed host + scheme are allowlisted AND the path ends + in a model extension. Matching on ``parsed.hostname`` (not a string prefix) + defeats userinfo tricks like ``http://127.0.0.1:80@169.254.169.254/x.safetensors``, + whose real host is ``169.254.169.254``; the extension check rejects non-model + URLs on allowed hosts (e.g. ``huggingface.co/api/...``). + """ + if not isinstance(url, str) or not url: + return False + try: + parsed = urlparse(url) + except ValueError: + return False + host = parsed.hostname + if host is None or parsed.scheme not in _ALLOWED_HOSTS.get(host, ()): + return False + return any(parsed.path.endswith(ext) for ext in _ALLOWED_MODEL_EXTENSIONS) diff --git a/app/model_downloader/api/routes.py b/app/model_downloader/api/routes.py new file mode 100644 index 000000000..b07e2a14c --- /dev/null +++ b/app/model_downloader/api/routes.py @@ -0,0 +1,359 @@ +"""Aiohttp routes for the server-side model download subsystem. + +Endpoint surface (all under ``/api/``, all kebab-case): + + - ``POST /api/models-availability-status`` — bulk status + metadata query. + - ``POST /api/download-models`` — start a batch of downloads. + - ``POST /api/cancel-model-download-session`` — cancel a single in-flight one. + - ``GET /api/hf-auth-token-status`` — current HF login state. + - ``POST /api/hf-auth-login-start`` — begin the HF OAuth flow. + - ``POST /api/hf-auth-logout`` — drop the stored HF token. + +The contract is intentionally narrow: only model_ids of the form +``/`` (validated via ``app.model_downloader.paths``) +are accepted, and only URLs on the same allowlist the frontend already +uses (HuggingFace, Civitai, localhost) can be fetched. Both are required +to keep the server out of the SSRF business for this feature. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from typing import Any, Literal, Optional + +from aiohttp import web +from pydantic import BaseModel, ValidationError + +from app.model_downloader.allowlist import is_url_allowed +from app.model_downloader.download_server import ( + DOWNLOAD_SERVER, + DownloadSession, +) +from app.model_downloader.downloader import schedule_batch +from app.model_downloader.gated_detection import probe_url +from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE +from app.model_downloader.hf_auth.eligibility import is_hf_auth_eligible +from app.model_downloader.hf_auth.oauth import ( + OAuthCallbackError, + OAuthInProgressError, + start_login_flow, +) +from app.model_downloader.paths import ( + InvalidModelId, + parse_model_id, + resolve_existing, +) +from app.model_downloader.api import schemas_in, schemas_out + +ROUTES = web.RouteTableDef() + + +def register_routes(app: web.Application) -> None: + """Wire the model-downloader routes into the running aiohttp app. + + Called once from ``server.py`` during ``PromptServer`` startup. + """ + app.add_routes(ROUTES) + + +# ----- response helpers (same envelope as app/assets/api/routes.py) ----- + + +ErrorCode = Literal[ + "INVALID_JSON", + "INVALID_BODY", + "EMPTY_REQUEST", + "INVALID_MODEL_ID", + "URL_NOT_ALLOWED", + "ALREADY_AVAILABLE", + "ALREADY_DOWNLOADING", + "MODEL_NOT_DOWNLOADABLE", + "NOT_DOWNLOADING", + "HF_AUTH_NOT_ELIGIBLE", + "HF_AUTH_IN_PROGRESS", + "HF_AUTH_START_FAILED", +] + + +def _error(status: int, code: ErrorCode, message: str, details: dict | None = None) -> web.Response: + return web.json_response( + {"error": {"code": code, "message": message, "details": details or {}}}, + status=status, + ) + + +def _validation_error(code: ErrorCode, ve: ValidationError) -> web.Response: + return _error(400, code, "Validation failed.", {"errors": json.loads(ve.json())}) + + +def _ok(payload: BaseModel, status: int = 200) -> web.Response: + return web.json_response( + payload.model_dump(mode="json", exclude_none=False), + status=status, + ) + + +async def _parse_body(request: web.Request, model: type[BaseModel]) -> Any: + """Parse a JSON body into a pydantic model or raise a 400 response.""" + try: + raw = await request.json() + except json.JSONDecodeError: + return _error(400, "INVALID_JSON", "Request body must be valid JSON.") + try: + return model.model_validate(raw) + except ValidationError as ve: + return _validation_error("INVALID_BODY", ve) + + +# ----- 1. availability status (unified: state + metadata per id) ----- + + +@ROUTES.post("/api/models-availability-status") +async def models_availability_status(request: web.Request) -> web.Response: + """Return per-id ``{state, progress, file_size, is_hf_downloadable}``. + + State (``available`` / ``missing`` / ``downloading``) is cheap to + recompute per call. ``file_size`` and ``is_gated`` are cached + server-side per URL. ``is_hf_downloadable`` is recomputed every + call from the current token state — that's what makes login + license + acceptance show up in the UI within one poll cycle without any + frontend cache plumbing. + """ + parsed = await _parse_body(request, schemas_in.AvailabilityStatusRequest) + if isinstance(parsed, web.Response): + return parsed + + items = list(parsed.models.items()) + + # Run all probes concurrently; each is internally cached per URL. + probes = await asyncio.gather(*(probe_url(url) for _, url in items)) + + response_models: dict[str, schemas_out.ModelStatusEntry] = {} + for (model_id, _url), probe in zip(items, probes): + try: + parse_model_id(model_id) + except InvalidModelId: + # Ill-formed identifier: report as missing without 400-ing the + # whole batch — the workflow author probably typo'd. + response_models[model_id] = schemas_out.ModelStatusEntry( + state="missing", + file_size=probe.file_size, + is_hf_downloadable=probe.is_hf_downloadable, + ) + continue + + active = DOWNLOAD_SERVER.get(model_id) + if active is not None: + response_models[model_id] = schemas_out.ModelStatusEntry( + state="downloading", + progress=schemas_out.DownloadProgress( + bytes_downloaded=active.bytes_downloaded, + total_bytes=active.total_bytes, + progress=active.progress, + ), + file_size=probe.file_size, + is_hf_downloadable=probe.is_hf_downloadable, + ) + continue + + state: schemas_out.ModelState = ( + "available" if resolve_existing(model_id) is not None else "missing" + ) + response_models[model_id] = schemas_out.ModelStatusEntry( + state=state, + file_size=probe.file_size, + is_hf_downloadable=probe.is_hf_downloadable, + ) + + return _ok(schemas_out.AvailabilityStatusResponse( + models=response_models, + hf_auth=schemas_out.HfAuthStatus( + token_available=HF_AUTH_STORE.has_token(), + eligible=is_hf_auth_eligible(), + ), + )) + + +# ----- 2. start downloads ----- + + +@ROUTES.post("/api/download-models") +async def download_models(request: web.Request) -> web.Response: + parsed = await _parse_body(request, schemas_in.DownloadModelsRequest) + if isinstance(parsed, web.Response): + return parsed + + if not parsed.models: + return _error(400, "EMPTY_REQUEST", "No models supplied.") + + # ----- precondition pass: validate everything BEFORE registering anything ----- + # Atomic semantics: if any model fails any precondition (invalid id, + # not allow-listed URL, already on disk, already downloading, or gated), + # the entire request fails and no state is changed. + requested = list(parsed.models.items()) + + for model_id, url in requested: + try: + parse_model_id(model_id) + except InvalidModelId as e: + return _error(400, "INVALID_MODEL_ID", str(e), + {"model_id": model_id}) + + if not is_url_allowed(url): + return _error( + 400, "URL_NOT_ALLOWED", + "Server-side downloads only accept HuggingFace, Civitai, " + "or localhost URLs ending in a known model extension.", + {"model_id": model_id, "url": url}, + ) + + if resolve_existing(model_id) is not None: + return _error(409, "ALREADY_AVAILABLE", + f"Model already exists on disk: {model_id}", + {"model_id": model_id}) + + if DOWNLOAD_SERVER.is_downloading(model_id): + return _error(409, "ALREADY_DOWNLOADING", + f"A download for {model_id} is already in progress.", + {"model_id": model_id}) + + # Reachability check last — it's the only one that talks to the + # network. Concurrent probes. For HF URLs ``is_hf_downloadable`` + # reflects current token access; for non-HF URLs it's None, and we + # treat that as "no info, proceed". + probes = await asyncio.gather(*(probe_url(url) for _, url in requested)) + for (model_id, url), probe in zip(requested, probes): + if probe.is_hf_downloadable is False: + return _error( + 400, "MODEL_NOT_DOWNLOADABLE", + f"Model {model_id} is gated on HuggingFace and the current " + f"server token (if any) does not grant access.", + {"model_id": model_id, "url": url}, + ) + + # ----- registration pass: try_register is atomic per model_id ----- + # Defensive: another request might have raced past our pre-check + # between the loop above and here. try_register handles that. + sessions: list[DownloadSession] = [] + for model_id, url in requested: + session = DOWNLOAD_SERVER.try_register(model_id, url) + if session is None: + # Race: someone else got in. Roll back what we registered. + for s in sessions: + DOWNLOAD_SERVER.cancel(s.model_id) + return _error(409, "ALREADY_DOWNLOADING", + f"A download for {model_id} is already in progress (race).", + {"model_id": model_id}) + sessions.append(session) + + DOWNLOAD_SERVER.sweep_orphan_tmp_files() + schedule_batch(sessions) + logging.info( + "[model_downloader] scheduled %d downloads: %s", + len(sessions), [s.model_id for s in sessions], + ) + + return _ok(schemas_out.DownloadModelsResponse( + accepted=True, + scheduled=[s.model_id for s in sessions], + ), status=202) + + +# ----- 3. cancel a session ----- + + +@ROUTES.post("/api/cancel-model-download-session") +async def cancel_model_download_session(request: web.Request) -> web.Response: + parsed = await _parse_body(request, schemas_in.CancelDownloadSessionRequest) + if isinstance(parsed, web.Response): + return parsed + + try: + parse_model_id(parsed.model_id) + except InvalidModelId as e: + return _error(400, "INVALID_MODEL_ID", str(e), {"model_id": parsed.model_id}) + + cancelled = DOWNLOAD_SERVER.cancel(parsed.model_id) + if not cancelled: + return _error(404, "NOT_DOWNLOADING", + f"No active download for {parsed.model_id}.", + {"model_id": parsed.model_id}) + + return _ok(schemas_out.CancelDownloadSessionResponse(cancelled=True)) + + +# ----- 4. HuggingFace OAuth status / login start / logout ----- + + +@ROUTES.get("/api/hf-auth-token-status") +async def hf_auth_token_status(request: web.Request) -> web.Response: + """Return whether the server holds a usable HF token + its username. + + Used by the settings UI and (out-of-band) by the frontend on + login completion. ``token_available`` is true even if the cached + access_token is expired — as long as a refresh_token exists, the + user is "logged in" from their perspective. + """ + token_present = HF_AUTH_STORE.has_token() + username: Optional[str] = None + if token_present: + # Resolve the username via whoami. Done in a worker thread because + # huggingface_hub's whoami is synchronous + blocks on a network call. + tok = await HF_AUTH_STORE.get_valid_token() + if tok is not None: + try: + username = await asyncio.to_thread(_whoami_username, tok.access_token) + except Exception as e: + logging.debug("[hf_auth] whoami failed: %s", e) + return _ok(schemas_out.HfAuthTokenStatusResponse( + token_available=token_present, + username=username, + )) + + +def _whoami_username(token: str) -> Optional[str]: + """Sync helper: ask HF for the user name attached to a token.""" + from huggingface_hub import HfApi + info = HfApi().whoami(token=token) + if isinstance(info, dict): + return info.get("name") or info.get("fullname") + return None + + +@ROUTES.post("/api/hf-auth-login-start") +async def hf_auth_login_start(request: web.Request) -> web.Response: + """Begin one OAuth attempt: bind the callback port, return the URL. + + Rejected outright if this deployment isn't eligible (we don't + surface the option on multi-tenant / public-IP installs). + """ + if not is_hf_auth_eligible(): + return _error( + 403, "HF_AUTH_NOT_ELIGIBLE", + "This server is not eligible for interactive HuggingFace login. " + "It must be bound to a loopback address and not running in " + "--multi-user mode.", + ) + try: + url = await start_login_flow() + except OAuthInProgressError: + return _error( + 409, "HF_AUTH_IN_PROGRESS", + "Another HuggingFace login attempt is in progress. Try again " + "after it completes or times out.", + ) + except OAuthCallbackError as e: + return _error( + 503, "HF_AUTH_START_FAILED", + f"Could not start the HuggingFace login flow: {e}", + ) + return _ok(schemas_out.HfAuthLoginStartResponse(authorize_url=url)) + + +@ROUTES.post("/api/hf-auth-logout") +async def hf_auth_logout(request: web.Request) -> web.Response: + """Drop the in-memory + on-disk HF token.""" + HF_AUTH_STORE.clear() + return _ok(schemas_out.HfAuthLogoutResponse(logged_out=True)) diff --git a/app/model_downloader/api/schemas_in.py b/app/model_downloader/api/schemas_in.py new file mode 100644 index 000000000..d792520ea --- /dev/null +++ b/app/model_downloader/api/schemas_in.py @@ -0,0 +1,41 @@ +"""Request schemas for the model-downloader API. + +Each endpoint accepts a small JSON body. Pydantic enforces the shape at +the boundary; route handlers operate only on validated values past that. +""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class AvailabilityStatusRequest(BaseModel): + """``POST /api/models-availability-status``. + + Sent by the frontend on each poll. Each entry is ``{model_id: url}``; + the URL is the one declared in ``properties.models[i].url`` in the + workflow JSON and lets the server compute per-id metadata + (``file_size`` + ``is_hf_downloadable``) on the same request. + """ + models: dict[str, str] = Field(default_factory=dict) + + +class DownloadModelsRequest(BaseModel): + """``POST /api/download-models``. + + Same shape as the metadata request — the URL for each model_id. + Returns immediately after validation and scheduling. + """ + models: dict[str, str] = Field(default_factory=dict) + + +class CancelDownloadSessionRequest(BaseModel): + """``POST /api/cancel-model-download-session``.""" + model_id: str + + +__all__ = [ + "AvailabilityStatusRequest", + "DownloadModelsRequest", + "CancelDownloadSessionRequest", +] diff --git a/app/model_downloader/api/schemas_out.py b/app/model_downloader/api/schemas_out.py new file mode 100644 index 000000000..5571ecd28 --- /dev/null +++ b/app/model_downloader/api/schemas_out.py @@ -0,0 +1,81 @@ +"""Response schemas for the model-downloader API.""" + +from __future__ import annotations + +from typing import Literal, Optional + +from pydantic import BaseModel + + +ModelState = Literal["available", "missing", "downloading"] + + +class DownloadProgress(BaseModel): + """Embedded in a model entry when its state is ``downloading``.""" + bytes_downloaded: int + total_bytes: Optional[int] = None + progress: Optional[float] = None # fraction in [0,1]; null until total known + + +class ModelStatusEntry(BaseModel): + """Everything the UI needs to render one row, in one shot. + + ``state`` reflects what the server has on disk + in-flight; ``file_size`` + and ``is_hf_downloadable`` come from probes (intrinsic; cached). + The HF fields are populated for every poll (cached on the server), + so license-acceptance flips show up within one poll interval without + any frontend cache invalidation. + """ + state: ModelState + progress: Optional[DownloadProgress] = None + file_size: Optional[int] = None + # HF-only: True iff the server can fetch this URL with current auth + # state. False iff gated and lacking access. None for non-HF URLs. + is_hf_downloadable: Optional[bool] = None + + +class HfAuthStatus(BaseModel): + """Snapshot of HF login state, embedded in availability response.""" + token_available: bool + eligible: bool + + +class AvailabilityStatusResponse(BaseModel): + models: dict[str, ModelStatusEntry] + hf_auth: HfAuthStatus + + +class DownloadModelsResponse(BaseModel): + accepted: bool + scheduled: list[str] + + +class CancelDownloadSessionResponse(BaseModel): + cancelled: bool + + +class HfAuthTokenStatusResponse(BaseModel): + token_available: bool + username: Optional[str] = None + + +class HfAuthLoginStartResponse(BaseModel): + authorize_url: str + + +class HfAuthLogoutResponse(BaseModel): + logged_out: bool + + +__all__ = [ + "ModelState", + "DownloadProgress", + "ModelStatusEntry", + "HfAuthStatus", + "AvailabilityStatusResponse", + "DownloadModelsResponse", + "CancelDownloadSessionResponse", + "HfAuthTokenStatusResponse", + "HfAuthLoginStartResponse", + "HfAuthLogoutResponse", +] diff --git a/app/model_downloader/download_server.py b/app/model_downloader/download_server.py new file mode 100644 index 000000000..c323309d8 --- /dev/null +++ b/app/model_downloader/download_server.py @@ -0,0 +1,179 @@ +"""Process-wide registry of in-flight model downloads. + +A single instance, ``DOWNLOAD_SERVER``, tracks every currently-running +server-side model fetch. Designed to be safe with multiple concurrent +clients hitting the API: each model_id has at most one active session, +and the API rejects requests that conflict with in-flight downloads. + +Cancellation is cooperative — the download loop checks ``is_active`` on +its own session between chunks and raises ``DownloadCancelled`` when the +session has been removed from the registry. This avoids the complications +of ``Task.cancel()`` from outside the loop while still giving deterministic +rollback semantics (the worker is responsible for deleting its own +``.tmp`` on the cancel path). +""" + +from __future__ import annotations + +import logging +import os +import threading +from dataclasses import dataclass, field +from typing import Optional + +from app.model_downloader.paths import iter_all_tmp_paths + + +class DownloadCancelled(Exception): + """Raised by the streaming loop when its session has been removed + from the registry (cancellation request) and the worker should roll + back its ``.tmp`` file.""" + + +@dataclass +class DownloadSession: + """One in-flight download. + + ``progress`` is a fraction in ``[0.0, 1.0]``; ``None`` until the first + byte arrives and we know whether the response carries a + ``Content-Length``. ``total_bytes`` mirrors that header when present. + """ + model_id: str + url: str + progress: Optional[float] = None + bytes_downloaded: int = 0 + total_bytes: Optional[int] = None + # Sequence number used solely as identity for the cancellation check — + # so that "cancel + restart" doesn't get confused by stale workers. + epoch: int = field(default_factory=lambda: 0) + + +class DownloadServer: + """Singleton registry of active downloads. + + All mutation goes through this object so concurrent route handlers + see a consistent view. The ``_lock`` is a plain threading lock + because the registry is consulted from both the asyncio event-loop + thread (route handlers) and from any worker coroutines spawned to + perform downloads. + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._sessions: dict[str, DownloadSession] = {} + self._epoch_counter = 0 + self._orphan_sweep_done = False + + # ----- lifecycle ----- + + def sweep_orphan_tmp_files(self) -> None: + """Idempotently sweep ``*.tmp`` files left by crashed downloads. + + Deferred off the import path so module load doesn't block on + filesystem I/O against potentially-slow mounts. Each route handler + that might create a new ``.tmp`` runs this exactly once. + """ + with self._lock: + if self._orphan_sweep_done: + return + self._orphan_sweep_done = True + for path in iter_all_tmp_paths(): + try: + os.remove(path) + logging.info("[model_downloader] removed orphan tmp file: %s", path) + except OSError as e: + logging.warning("[model_downloader] could not remove %s: %s", path, e) + + # ----- queries ----- + + def is_downloading(self, model_id: str) -> bool: + with self._lock: + return model_id in self._sessions + + def get(self, model_id: str) -> Optional[DownloadSession]: + with self._lock: + return self._sessions.get(model_id) + + def snapshot(self) -> dict[str, DownloadSession]: + """Return a shallow copy of the current sessions map.""" + with self._lock: + return dict(self._sessions) + + # ----- mutations ----- + + def try_register(self, model_id: str, url: str) -> Optional[DownloadSession]: + """Atomically register a new session iff none exists for ``model_id``. + + Returns the new session on success, ``None`` if a session is already + in flight. Callers must check the return value — the caller is the + sole owner of the session it gets back. + """ + with self._lock: + if model_id in self._sessions: + return None + self._epoch_counter += 1 + session = DownloadSession( + model_id=model_id, + url=url, + epoch=self._epoch_counter, + ) + self._sessions[model_id] = session + return session + + def update_progress( + self, + session: DownloadSession, + bytes_downloaded: int, + total_bytes: Optional[int], + ) -> None: + """Update progress on a session. No-op if the session has been + removed (cancelled) — caller should check ``is_active`` separately.""" + with self._lock: + current = self._sessions.get(session.model_id) + if current is None or current.epoch != session.epoch: + return + current.bytes_downloaded = bytes_downloaded + current.total_bytes = total_bytes + if total_bytes and total_bytes > 0: + current.progress = min(1.0, bytes_downloaded / total_bytes) + + def is_active(self, session: DownloadSession) -> bool: + """True iff this exact session is still the registered one for + its model_id. False after cancellation, after completion, or if + another session has replaced it.""" + with self._lock: + current = self._sessions.get(session.model_id) + return current is not None and current.epoch == session.epoch + + def finish(self, session: DownloadSession) -> None: + """Remove a completed (or cancelled) session from the registry. + + Safe to call multiple times. Only removes if the epoch matches — + we never accidentally evict a *newer* session for the same model_id. + """ + with self._lock: + current = self._sessions.get(session.model_id) + if current is not None and current.epoch == session.epoch: + del self._sessions[session.model_id] + + def reset_for_tests(self) -> None: + """Clear all sessions and reset the epoch counter. Test-only.""" + with self._lock: + self._sessions.clear() + self._epoch_counter = 0 + + def cancel(self, model_id: str) -> bool: + """Remove the session registered for ``model_id``. + + Returns True if there was an active session to cancel. The worker + will discover the cancellation on its next ``is_active`` check + and roll back its ``.tmp`` file. + """ + with self._lock: + if model_id in self._sessions: + del self._sessions[model_id] + return True + return False + + +DOWNLOAD_SERVER = DownloadServer() diff --git a/app/model_downloader/downloader.py b/app/model_downloader/downloader.py new file mode 100644 index 000000000..f30b88aa2 --- /dev/null +++ b/app/model_downloader/downloader.py @@ -0,0 +1,216 @@ +"""Streaming download worker with progress reporting and cancellation. + +Each download writes to ``.tmp`` and atomically renames into +place on success. Between chunks the worker checks the registry for +cancellation (via ``DownloadServer.is_active``) and rolls back its +``.tmp`` on cancel or on any error. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +from typing import Optional + +import aiohttp + +from app.model_downloader.download_server import ( + DOWNLOAD_SERVER, + DownloadCancelled, + DownloadSession, +) +from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE +from app.model_downloader.hf_url import is_hf_url +from app.model_downloader.http_client import get_session, parse_content_length +from app.model_downloader.paths import resolve_destination + + +CHUNK_SIZE = 64 * 1024 # 64 KiB — same scale as other ComfyUI download paths. +REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=120) + + +async def stream_to_disk(session: DownloadSession) -> str: + """Run a single download to completion or cancellation. + + Returns the final on-disk path on success. Removes the ``.tmp`` and + raises on cancellation or failure. The session is finished + (removed from the registry) exactly once, here — callers do not + need to call ``DOWNLOAD_SERVER.finish`` themselves. + """ + final_path, tmp_path = resolve_destination(session.model_id, session.epoch) + os.makedirs(os.path.dirname(final_path), exist_ok=True) + + # Wipe any stale .tmp from a previous failed attempt before we start — + # otherwise a partial body could masquerade as our completed download + # when the rename finally happens. + _remove_if_exists(tmp_path) + + bytes_seen = 0 + try: + http = await get_session() + headers = _auth_headers_for(session.url) + logging.info( + "[model_downloader] starting GET %s (auth=%s)", + session.url, "yes" if "Authorization" in headers else "no", + ) + async with http.get( + session.url, + allow_redirects=True, + timeout=REQUEST_TIMEOUT, + headers=headers, + ) as resp: + if resp.status != 200: + # Capture a snippet of the response body so 4xx/5xx aren't + # opaque in the logs — HF returns JSON or HTML with a + # human-readable reason on failures. + body_snippet = await _read_short(resp) + logging.warning( + "[model_downloader] GET %s failed: status=%d final_url=%s body=%s", + session.url, resp.status, str(resp.url), body_snippet, + ) + raise DownloadError( + f"unexpected HTTP {resp.status} fetching {session.url}: {body_snippet}", + status=resp.status, + ) + + total = parse_content_length(resp.headers.get("Content-Length")) + DOWNLOAD_SERVER.update_progress(session, 0, total) + + with open(tmp_path, "wb") as f: + async for chunk in resp.content.iter_chunked(CHUNK_SIZE): + # Cancellation check between chunks. Cheap and means + # cancellation latency is bounded by one chunk plus + # one ``write()`` — typically well under a second + # even on slow disks. + if not DOWNLOAD_SERVER.is_active(session): + raise DownloadCancelled() + f.write(chunk) + bytes_seen += len(chunk) + DOWNLOAD_SERVER.update_progress(session, bytes_seen, total) + + # Final cancellation check before we promote the .tmp to the real + # filename — avoids the awkward case where cancel arrives during + # the very last chunk and we'd otherwise commit anyway. + if not DOWNLOAD_SERVER.is_active(session): + raise DownloadCancelled() + + # Size verification before commit. aiohttp already raises + # ClientPayloadError on a truncated Content-Length/chunked body, + # but this also catches the HTTP/1.0-style case (no Content-Length + # + Connection: close) where a short read can masquerade as a + # complete download. + if total is not None and bytes_seen != total: + raise DownloadError( + f"size mismatch for {session.model_id}: " + f"got {bytes_seen} of {total} bytes from {session.url}" + ) + + # Atomic rename. os.replace is atomic within the same filesystem, + # which is guaranteed here because tmp lives alongside final_path. + os.replace(tmp_path, final_path) + logging.info( + "[model_downloader] downloaded %s (%d bytes) from %s", + session.model_id, bytes_seen, session.url, + ) + return final_path + + except DownloadCancelled: + logging.info("[model_downloader] cancelled: %s", session.model_id) + _remove_if_exists(tmp_path) + raise + except Exception as e: + logging.warning( + "[model_downloader] failed: %s from %s: %s: %s", + session.model_id, session.url, type(e).__name__, e, + exc_info=True, + ) + _remove_if_exists(tmp_path) + raise + finally: + # In all terminal states (success / cancel / error) drop the + # session from the registry. Idempotent — only removes if we're + # still the live epoch for this model_id. + DOWNLOAD_SERVER.finish(session) + + +class DownloadError(Exception): + """Network / protocol error during a download.""" + + def __init__(self, message: str, status: Optional[int] = None) -> None: + super().__init__(message) + self.status = status + + +async def _read_short(resp: aiohttp.ClientResponse, limit: int = 512) -> str: + """Read up to ``limit`` bytes of a response body for logging. + + Used to surface the JSON/HTML reason from an HF non-2xx response in + server logs instead of just the status code. Best-effort: any + error here is swallowed. + """ + try: + raw = await resp.content.read(limit) + return raw.decode("utf-8", errors="replace").strip() + except Exception: + return "" + + +def _auth_headers_for(url: str) -> dict[str, str]: + """Return any auth headers we should add to the GET for ``url``. + + For HuggingFace URLs we inject the user's OAuth access token as a + Bearer header — this is HF's documented way to access gated repos + (see ``huggingface_hub.hf_hub_download``'s wire format). For every + other host we send no extra headers; allowlisted public files + don't need them and we don't want to leak tokens to other hosts. + """ + if not is_hf_url(url): + return {} + tok = HF_AUTH_STORE.get_token_sync() + if tok is None or not tok.access_token: + return {} + return {"Authorization": f"Bearer {tok.access_token}"} + + +def _remove_if_exists(path: str) -> None: + try: + os.remove(path) + except FileNotFoundError: + pass + except OSError as e: + logging.warning("[model_downloader] could not remove %s: %s", path, e) + + +async def run_batch_sequential(sessions: list[DownloadSession]) -> None: + """Run a list of sessions one after the other. + + Each session is independent: a failure or cancellation of one does + not abort the rest. Cancellations are observable via the registry + *before* a given download starts, so a session that's been + pre-cancelled (cancel before the worker reached it) just gets skipped. + """ + for session in sessions: + # If the session got cancelled before its turn, skip without + # touching disk. This is what makes the per-request "sequential + # but cancellable" semantic work. + if not DOWNLOAD_SERVER.is_active(session): + DOWNLOAD_SERVER.finish(session) + continue + try: + await stream_to_disk(session) + except DownloadCancelled: + # Already logged + tmp removed inside stream_to_disk. + continue + except Exception: + # stream_to_disk already logged. Continue with the rest of the batch. + continue + + +def schedule_batch(sessions: list[DownloadSession]) -> asyncio.Task: + """Kick off ``run_batch_sequential`` on the running event loop. + + Returned task is fire-and-forget; the API handler returns immediately + after scheduling and clients observe progress via the polling endpoints. + """ + return asyncio.create_task(run_batch_sequential(sessions)) diff --git a/app/model_downloader/gated_detection.py b/app/model_downloader/gated_detection.py new file mode 100644 index 000000000..8749c99ba --- /dev/null +++ b/app/model_downloader/gated_detection.py @@ -0,0 +1,245 @@ +"""Per-URL probes for the unified availability endpoint. + +Three cached/derived facts per URL: + + - ``is_gated`` intrinsic to the model; cached forever once known. + Determined by ``auth_check(repo_id, token=None)``: + ``GatedRepoError`` → True, success → False. + + - ``is_hf_downloadable`` depends on the *current* token; recomputed every + call. For non-gated URLs this is trivially True + (no HF call needed). For gated URLs we run + ``auth_check`` with the stored token each call. + + - ``file_size`` intrinsic to the file. Cached forever once + determined (including ``None`` on transient + failure — we don't retry). We only attempt the + HEAD when we already know the URL is downloadable + to us; that way a failed-because-gated probe + never lands as a cached ``None``. + +Caches are per-process, in-memory; small, no eviction needed for the +workflow-scale (~tens of URLs). Concurrent calls for the same URL +deduplicate via per-URL ``asyncio.Lock``. +""" + +from __future__ import annotations + +import asyncio +import logging +from dataclasses import dataclass +from typing import Optional + +import aiohttp +from huggingface_hub import HfApi +from huggingface_hub.errors import ( + GatedRepoError, + HfHubHTTPError, + RepositoryNotFoundError, +) + +from app.model_downloader.allowlist import is_url_allowed +from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE +from app.model_downloader.hf_url import is_hf_url, repo_id_from_url +from app.model_downloader.http_client import get_session, parse_content_length + + +_HEAD_TIMEOUT = aiohttp.ClientTimeout(total=15) + + +@dataclass +class ProbeResult: + file_size: Optional[int] + is_hf_downloadable: Optional[bool] + + +# --- caches -------------------------------------------------------------- # + + +# url → bool. Whether this URL's HF repo gates access. Intrinsic to the +# model — never changes for a given URL. +_is_gated_cache: dict[str, bool] = {} + +# url → Optional[int]. The file's size in bytes, ``None`` if a probe +# was attempted and produced no answer. **Only populated when we knew +# the URL was downloadable to us at probe time** — so gated-without- +# access never lands a ``None`` here that we'd be stuck with after login. +_file_size_cache: dict[str, Optional[int]] = {} + +# Per-URL locks for single-flight probes — when multiple polls arrive +# in the same tick for the same URL, exactly one of them runs the HF +# call and the others wait on the result. +_locks: dict[str, asyncio.Lock] = {} + + +def _lock_for(url: str) -> asyncio.Lock: + lock = _locks.get(url) + if lock is None: + lock = asyncio.Lock() + _locks[url] = lock + return lock + + +def clear_caches_for_tests() -> None: + """Test-only: drop everything.""" + _is_gated_cache.clear() + _file_size_cache.clear() + _locks.clear() + + +# --- public entrypoint --------------------------------------------------- # + + +async def probe_url(url: str) -> ProbeResult: + """Return downloadability + size for one URL, using caches where safe.""" + if not is_url_allowed(url): + return ProbeResult(file_size=None, is_hf_downloadable=None) + if not is_hf_url(url): + # Non-HF: ``is_hf_downloadable`` is "not applicable" (None). + # Size we still cache so we don't HEAD on every poll. + size = await _get_or_probe_size(url, token=None) + return ProbeResult(file_size=size, is_hf_downloadable=None) + + repo_id = repo_id_from_url(url) + if repo_id is None: + return ProbeResult(file_size=None, is_hf_downloadable=None) + + # Determine intrinsic gating once. + gated = await _resolve_is_gated(url, repo_id) + if gated is None: + return ProbeResult(file_size=None, is_hf_downloadable=None) + + # Compute current-token downloadability per call. + tok = await HF_AUTH_STORE.get_valid_token() + token_str: Optional[str] = tok.access_token if tok else None + if not gated: + is_hf_downloadable: Optional[bool] = True + else: + is_hf_downloadable = await _auth_check_with_token(repo_id, token_str) + + if is_hf_downloadable is True: + size = await _get_or_probe_size(url, token=token_str) + else: + # Skip the HEAD entirely — would 401 and we'd be stuck with + # cached None that survives a later login. + size = None + + return ProbeResult(file_size=size, is_hf_downloadable=is_hf_downloadable) + + +# --- gated/auth probes --------------------------------------------------- # + + +async def _resolve_is_gated(url: str, repo_id: str) -> Optional[bool]: + """Decide once whether ``repo_id`` is a gated repo.""" + cached = _is_gated_cache.get(url) + if cached is not None: + return cached + + async with _lock_for(url): + cached = _is_gated_cache.get(url) + if cached is not None: + return cached + # Probe anonymously (token=None) on purpose: an unauthenticated + # auth_check is what makes HF raise GatedRepoError for gated repos. + # With a token, a gated-but-accepted repo would succeed and look + # ungated. + try: + await asyncio.to_thread(_auth_check_sync, repo_id, None) + _is_gated_cache[url] = False + return False + except GatedRepoError: + _is_gated_cache[url] = True + return True + except RepositoryNotFoundError: + # Repo doesn't exist publicly. Treat as gated — we can't + # serve it without auth, and an authenticated check might + # still succeed if it's a private repo the user can see. + _is_gated_cache[url] = True + return True + except (HfHubHTTPError, Exception) as e: + logging.debug( + "[hf_auth] is_gated probe failed for %s (will retry): %s", + repo_id, e, + ) + return None # don't cache; retry next call + + +async def _auth_check_with_token( + repo_id: str, token: Optional[str] +) -> Optional[bool]: + """Auth-check with the supplied token. True/False/None per outcome.""" + try: + await asyncio.to_thread(_auth_check_sync, repo_id, token) + return True + except GatedRepoError: + return False + except RepositoryNotFoundError: + return False + except HfHubHTTPError as e: + # 401/403 covers org-SSO-required, revoked tokens, and similar — + # all of which mean "can't fetch right now" from the user's POV. + status = getattr(getattr(e, "response", None), "status_code", None) + if status in (401, 403): + return False + logging.debug( + "[hf_auth] auth_check transient failure for %s: %s", repo_id, e, + ) + return None + except Exception as e: + logging.warning("[hf_auth] unexpected auth_check error for %s: %s", repo_id, e) + return None + + +def _auth_check_sync(repo_id: str, token: Optional[str]) -> None: + """Thin sync wrapper around ``HfApi.auth_check`` for ``asyncio.to_thread``.""" + HfApi().auth_check(repo_id, token=token) + + +# --- size probe ---------------------------------------------------------- # + + +async def _get_or_probe_size(url: str, token: Optional[str]) -> Optional[int]: + """Return the cached size or HEAD the URL once and cache the result.""" + if url in _file_size_cache: + return _file_size_cache[url] + + async with _lock_for(url): + if url in _file_size_cache: + return _file_size_cache[url] + size = await _probe_size_once(url, token=token) + _file_size_cache[url] = size + return size + + +async def _probe_size_once(url: str, token: Optional[str]) -> Optional[int]: + """HEAD the URL and return the file size in bytes, or None on failure. + + HuggingFace serves LFS-tracked files via a 302 to a signed CDN URL. + The real file size lives in the non-standard ``X-Linked-Size`` header + on that 302 response (``Content-Length`` is the redirect-body length). + Disabling redirect-follow lets us read either header on the same + response: + + - LFS files: 302 + ``X-Linked-Size`` + - Small/non-LFS files: 200 + ``Content-Length`` + """ + headers = {"Authorization": f"Bearer {token}"} if token else {} + try: + session = await get_session() + async with session.head( + url, allow_redirects=False, timeout=_HEAD_TIMEOUT, headers=headers, + ) as resp: + linked = parse_content_length(resp.headers.get("X-Linked-Size")) + if linked is not None: + return linked + if resp.status == 200: + return parse_content_length(resp.headers.get("Content-Length")) + return None + except (aiohttp.ClientError, TimeoutError, OSError): + return None + + +# Backward-compat shim so consumers that still import the old name keep +# building during the refactor; can be removed once routes are updated. +MetadataProbeResult = ProbeResult diff --git a/app/model_downloader/hf_auth/auth_store.py b/app/model_downloader/hf_auth/auth_store.py new file mode 100644 index 000000000..069fc2d26 --- /dev/null +++ b/app/model_downloader/hf_auth/auth_store.py @@ -0,0 +1,121 @@ +"""In-memory token cache with lazy disk persistence + refresh. + +Public surface is the ``HF_AUTH_STORE`` singleton. Callers ask +``get_valid_token()``; the store transparently refreshes from disk +on first use, refreshes via the OAuth refresh_token if the cached +access_token is expired, and returns ``None`` if neither path works. + +The refresh path imports ``oauth.refresh_access_token`` lazily to +avoid an import cycle (oauth needs the store to save tokens it +acquires). +""" + +from __future__ import annotations + +import logging +import threading +from typing import Optional + +from app.model_downloader.hf_auth.token_store import ( + Token, + delete_token, + load_token, + save_token, +) + + +class HfAuthStore: + def __init__(self) -> None: + self._lock = threading.Lock() + self._token: Optional[Token] = None + self._loaded_from_disk = False + + def _ensure_loaded(self) -> None: + """Read the disk token into memory on first access.""" + if self._loaded_from_disk: + return + with self._lock: + if self._loaded_from_disk: + return + self._token = load_token() + self._loaded_from_disk = True + + def has_token(self) -> bool: + """Cheap check: is there any token in memory? + + Does not attempt refresh; an expired-but-refreshable token still + counts as "logged in" from the user's perspective. + """ + self._ensure_loaded() + return self._token is not None + + def _store_token_locked(self, token: Token) -> None: + """Set the in-memory token and persist it to disk. + + Caller must already hold ``self._lock``. Keeping the disk write inside + the lock means memory and disk flip together — a concurrent ``clear()`` + or refresh can't interleave between them. + """ + self._token = token + self._loaded_from_disk = True + save_token(token) + + def set_token(self, token: Token) -> None: + """Replace the in-memory token and persist to disk (atomically).""" + with self._lock: + self._store_token_locked(token) + + def clear(self) -> None: + """Forget the token in memory and on disk (logout).""" + with self._lock: + self._token = None + self._loaded_from_disk = True + delete_token() + + def get_token_sync(self) -> Optional[Token]: + """Return the cached token without refreshing. + + Sync callers (e.g. constructing an Authorization header in a + non-async path) use this. They accept an expired token over + ``None``; HF will simply return 401 and the caller can decide + what to do. + """ + self._ensure_loaded() + return self._token + + async def get_valid_token(self) -> Optional[Token]: + """Return a fresh token, refreshing via OAuth if necessary. + + Returns ``None`` if there's no cached token at all, or if the + cached token is expired and refresh failed. Callers should + treat that as "user is not logged in". + """ + self._ensure_loaded() + with self._lock: + tok = self._token + if tok is None: + return None + if tok.is_valid(): + return tok + if not tok.refresh_token: + return None + + # Lazy import to avoid the oauth ↔ store import cycle. + from app.model_downloader.hf_auth.oauth import refresh_access_token + + try: + refreshed = await refresh_access_token(tok.refresh_token) + except Exception as e: + logging.warning("[hf_auth] token refresh failed: %s", e) + return None + + with self._lock: + # If a logout (clear) or another update replaced the token while we + # were awaiting the refresh, don't resurrect the old session. + if self._token is not tok: + return None + self._store_token_locked(refreshed) + return refreshed + + +HF_AUTH_STORE = HfAuthStore() diff --git a/app/model_downloader/hf_auth/eligibility.py b/app/model_downloader/hf_auth/eligibility.py new file mode 100644 index 000000000..ad788e4cd --- /dev/null +++ b/app/model_downloader/hf_auth/eligibility.py @@ -0,0 +1,55 @@ +"""Whether this deployment is allowed to do interactive HF OAuth. + +We only let the server hold a HuggingFace token under a strict trust +assumption: this is a *single tenant local* install. Concretely: + + - The server is bound to a loopback address. SSH tunneling / + reverse-proxies can defeat this, but it's the strongest signal + we have without an authentication system. + - ``--multi-user`` is off. A shared token used implicitly by multiple + declared users would be a footgun — one user's gated downloads + would silently authenticate as another. + +Anything else and the frontend hides the HF login UI entirely; gated +models continue to show the "acquire it manually" message. +""" + +from __future__ import annotations + +import ipaddress +import socket + +from comfy.cli_args import args + + +def _is_loopback(host: str | None) -> bool: + """Duplicates ``server.is_loopback`` (small, no shared module yet). + + Resolves a host or IP literal to whether it lives on the loopback + interface (127.0.0.0/8 for IPv4, ::1 for IPv6). Returns False for + ``0.0.0.0`` / ``::`` because those are bind-all wildcards, not + loopback. + """ + if host is None: + return False + try: + return ipaddress.ip_address(host).is_loopback + except ValueError: + pass + + loopback = False + for family in (socket.AF_INET, socket.AF_INET6): + try: + r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM) + for _family, _, _, _, sockaddr in r: + if not ipaddress.ip_address(sockaddr[0]).is_loopback: + return loopback + loopback = True + except socket.gaierror: + pass + return loopback + + +def is_hf_auth_eligible() -> bool: + """True iff this deployment may surface the HF OAuth flow.""" + return _is_loopback(args.listen) and not args.multi_user diff --git a/app/model_downloader/hf_auth/oauth.py b/app/model_downloader/hf_auth/oauth.py new file mode 100644 index 000000000..b3527fb10 --- /dev/null +++ b/app/model_downloader/hf_auth/oauth.py @@ -0,0 +1,301 @@ +"""OAuth 2.0 PKCE flow against HuggingFace's authorization server. + +Wired so that ``POST /api/hf-auth-login-start`` can: + 1. Generate state + PKCE verifier/challenge in this process. + 2. Spin up a short-lived loopback HTTP server at port 41954 to + receive the redirect callback from HF. + 3. Return the ``authorize_url`` for the frontend to open in a new tab. + +After the user grants consent on huggingface.co, HF redirects to the +local callback URL with ``code`` and ``state``. The callback server +validates ``state`` (CSRF), exchanges the code for tokens via PKCE, +hands the resulting Token to ``HF_AUTH_STORE.set_token``, and shuts +itself down. + +Before this can be exercised end-to-end a maintainer must register a +HuggingFace OAuth app and substitute the ``HF_CLIENT_ID`` placeholder +below. See the comment above the constant for the exact steps. +""" + +from __future__ import annotations + +import asyncio +import base64 +import hashlib +import logging +import secrets +import threading +import time + +import aiohttp +from aiohttp import web + +from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE +from app.model_downloader.hf_auth.token_store import Token +from app.model_downloader.http_client import get_session + + +# --- HF OAuth app registration -------------------------------------------- # +# NOTE: The OAuth client_id below is a placeholder. Before this feature can be +# exercised end-to-end, a maintainer must register a HuggingFace OAuth app +# under a Comfy-Org-controlled HF account and substitute its client_id here. +# Detailed walkthrough is in docs/server-side-model-downloads-handover.html +# ("HuggingFace OAuth app setup" section). Short version: +# 1. huggingface.co → Settings → Connected Apps → "Create app" +# 2. Default Scopes: check ``openid`` + ``profile`` (User Info) and +# ``gated-repos`` (Repository Access). Leave everything else off. +# 3. Redirect URLs: exactly ``http://127.0.0.1:41954/api/auth/huggingface/callback`` +# — must match ``REDIRECT_URI`` below; change both in lockstep if you +# change ``CALLBACK_PORT``. +# 4. Save → copy the resulting Client ID into ``HF_CLIENT_ID`` below. +# The client_id is not a secret (it travels through the user's browser in +# plaintext); HF's "Public app" type means there's no client secret to +# manage — PKCE replaces it. +HF_CLIENT_ID = "REPLACE_ME_WITH_COMFY_ORG_HF_OAUTH_CLIENT_ID" + +CALLBACK_HOST = "127.0.0.1" +CALLBACK_PORT = 41954 +CALLBACK_PATH = "/api/auth/huggingface/callback" +REDIRECT_URI = f"http://{CALLBACK_HOST}:{CALLBACK_PORT}{CALLBACK_PATH}" + +AUTHORIZE_URL = "https://huggingface.co/oauth/authorize" +TOKEN_URL = "https://huggingface.co/oauth/token" +_TOKEN_REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=30) +# Minimal scope set for the feature: +# - openid : required by HF when the app uses OIDC at all +# - profile : lets ``HfApi.whoami(token=...)`` return a username for the +# settings UI; cosmetic but expected +# - gated-repos : grants the token enough to call ``auth_check`` and +# download files from public gated repos the user has +# accepted the license for. The wider ``read-repos`` scope +# would also work (it includes ``gated-repos``) but it +# additionally grants private-repo read access, which we +# don't need and which makes the consent screen scarier +# for the user. +SCOPE = "openid profile gated-repos" + +# Maximum time the callback server stays up waiting for the user to +# complete consent on huggingface.co. Past this, the port closes and +# the user has to click "Log in" again. +CALLBACK_TIMEOUT_SECS = 300 + + +# Process-wide lock so two simultaneous /api/hf-auth-login-start +# requests don't fight over port CALLBACK_PORT. +_OAUTH_LOCK = threading.Lock() + + +class OAuthInProgressError(Exception): + """Another OAuth attempt is already running.""" + + +class OAuthCallbackError(Exception): + """The OAuth callback returned an error (HF denied, port stolen, etc.).""" + + +# --- PKCE primitives ------------------------------------------------------ # + + +def _make_pkce() -> tuple[str, str, str]: + """Return ``(verifier, challenge, state)``. + + Verifier never leaves this process. Challenge and state travel + through the user's browser. State is checked on the callback to + prevent a malicious cross-origin redirect from injecting a token. + """ + verifier = secrets.token_urlsafe(64) + challenge = ( + base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("ascii")).digest()) + .rstrip(b"=") + .decode("ascii") + ) + state = secrets.token_urlsafe(32) + return verifier, challenge, state + + +def _build_authorize_url(challenge: str, state: str) -> str: + from urllib.parse import urlencode + + params = { + "client_id": HF_CLIENT_ID, + "redirect_uri": REDIRECT_URI, + "response_type": "code", + "scope": SCOPE, + "state": state, + "code_challenge": challenge, + "code_challenge_method": "S256", + } + return f"{AUTHORIZE_URL}?{urlencode(params)}" + + +# --- Token exchange ------------------------------------------------------- # + + +async def _exchange_code(code: str, verifier: str) -> Token: + """Trade the authorization code for an access+refresh token pair.""" + data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": REDIRECT_URI, + "client_id": HF_CLIENT_ID, + "code_verifier": verifier, + } + session = await get_session() + async with session.post(TOKEN_URL, data=data, timeout=_TOKEN_REQUEST_TIMEOUT) as resp: + resp.raise_for_status() + body = await resp.json() + return Token( + access_token=body["access_token"], + refresh_token=body.get("refresh_token"), + expires_at=time.time() + float(body.get("expires_in", 3600)), + scope=body.get("scope", SCOPE), + ) + + +async def refresh_access_token(refresh_token: str) -> Token: + """Trade a refresh_token for a new access (+ possibly refresh) token.""" + data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": HF_CLIENT_ID, + } + session = await get_session() + async with session.post(TOKEN_URL, data=data, timeout=_TOKEN_REQUEST_TIMEOUT) as resp: + resp.raise_for_status() + body = await resp.json() + return Token( + access_token=body["access_token"], + # If HF doesn't rotate refresh tokens, keep using the existing one. + refresh_token=body.get("refresh_token", refresh_token), + expires_at=time.time() + float(body.get("expires_in", 3600)), + scope=body.get("scope", SCOPE), + ) + + +# --- Callback server ------------------------------------------------------ # + + +async def start_login_flow() -> str: + """Begin one OAuth attempt: spawn the callback server, return the URL. + + Returns the URL the frontend should open in a new tab. Raises + ``OAuthInProgressError`` if another attempt is already running. + The callback server runs in the background until the user + completes consent or until ``CALLBACK_TIMEOUT_SECS`` elapses; + either way the lock + port are released afterward. + """ + if not _OAUTH_LOCK.acquire(blocking=False): + raise OAuthInProgressError() + + try: + verifier, challenge, state = _make_pkce() + authorize_url = _build_authorize_url(challenge, state) + ready: asyncio.Future[None] = asyncio.get_event_loop().create_future() + except BaseException: + # Failed before handing the lock to the callback-server task: release it + # here. (Once the task is spawned, it owns releasing the lock.) + _OAUTH_LOCK.release() + raise + + asyncio.create_task(_run_callback_server(verifier, state, ready)) + # Don't return the URL until the callback server is actually bound and + # listening — otherwise HF could redirect to a port nothing is serving and + # the login would silently dead-end. ``ready`` raises if the bind failed. + await ready + return authorize_url + + +async def _run_callback_server( + verifier: str, expected_state: str, ready: "asyncio.Future[None]" +) -> None: + """Listen for HF's redirect once, capture the token, then shut down. + + Signals ``ready`` once the port is bound (or with an exception if the bind + fails), so ``start_login_flow`` only hands back a URL on a live server. + """ + received: asyncio.Future[Token] = asyncio.get_event_loop().create_future() + + async def handler(request: web.Request) -> web.Response: + try: + if request.query.get("state") != expected_state: + return web.Response(status=400, text="state mismatch") + err = request.query.get("error") + if err: + received.set_exception(OAuthCallbackError(f"HF returned: {err}")) + return web.Response(status=400, text=f"OAuth error: {err}") + code = request.query.get("code") + if not code: + return web.Response(status=400, text="missing code") + tok = await _exchange_code(code, verifier) + if not received.done(): + received.set_result(tok) + return web.Response( + content_type="text/html", + text=( + "" + "

HuggingFace login successful

" + "

You can close this tab and return to ComfyUI.

" + "" + ), + ) + except Exception as exc: + if not received.done(): + received.set_exception(exc) + return web.Response(status=500, text=str(exc)) + + app = web.Application() + app.router.add_get(CALLBACK_PATH, handler) + runner = web.AppRunner(app) + try: + await runner.setup() + site = web.TCPSite(runner, CALLBACK_HOST, CALLBACK_PORT, reuse_address=True) + await site.start() + except Exception as e: + # Couldn't bind the callback port (commonly already in use). Tell the + # waiting start_login_flow via ``ready`` so it surfaces a clear error + # instead of returning a dead URL, and release the lock for next time. + logging.warning("[hf_auth] could not start callback server: %s", e) + if not ready.done(): + ready.set_exception( + OAuthCallbackError(f"could not bind callback port {CALLBACK_PORT}: {e}") + ) + _OAUTH_LOCK.release() + return + + # Bound and listening — now it's safe for start_login_flow to return the URL. + if not ready.done(): + ready.set_result(None) + + try: + token = await asyncio.wait_for(received, timeout=CALLBACK_TIMEOUT_SECS) + except asyncio.TimeoutError: + logging.info("[hf_auth] OAuth login timed out after %ds", CALLBACK_TIMEOUT_SECS) + return + except OAuthCallbackError as e: + logging.warning("[hf_auth] OAuth callback error: %s", e) + return + except Exception as e: + logging.warning("[hf_auth] unexpected OAuth failure: %s", e) + return + else: + HF_AUTH_STORE.set_token(token) + logging.info("[hf_auth] OAuth login complete") + finally: + await runner.cleanup() + if _OAUTH_LOCK.locked(): + _OAUTH_LOCK.release() + + +def is_login_in_progress() -> bool: + """True iff a callback server is currently bound + waiting.""" + return _OAUTH_LOCK.locked() + + +# Re-export for callers that only want the URL builder (e.g. tests). +__all__ = [ + "start_login_flow", + "refresh_access_token", + "is_login_in_progress", + "OAuthInProgressError", + "CALLBACK_TIMEOUT_SECS", +] diff --git a/app/model_downloader/hf_auth/token_store.py b/app/model_downloader/hf_auth/token_store.py new file mode 100644 index 000000000..7dfec656f --- /dev/null +++ b/app/model_downloader/hf_auth/token_store.py @@ -0,0 +1,94 @@ +"""On-disk persistence for the HuggingFace OAuth token. + +The token shape mirrors what HF returns on the token exchange: an +``access_token``, an optional ``refresh_token``, the absolute epoch at +which the access token expires, and the granted scope. We persist +this so logging in once survives ComfyUI restarts under the internal +``__hf_auth`` system-user directory; the file is mode ``0600`` so only +the OS user can read it. +""" + +from __future__ import annotations + +import json +import logging +import os +import stat +import time +from dataclasses import asdict, dataclass +from typing import Optional + +import folder_paths + + +# Treat a token as expired this many seconds before its server-reported +# ``expires_at`` so we don't try to use a token mid-request only for it +# to flip stale between auth_check and the actual GET. +EXPIRY_BUFFER_SECS = 60 + +TOKEN_FILENAME = "hf_auth_token.json" + + +@dataclass +class Token: + """One OAuth token + the metadata we need to use it.""" + access_token: str + refresh_token: Optional[str] + expires_at: float # absolute epoch seconds + scope: str = "" + + def is_valid(self) -> bool: + """True iff we can use this token right now.""" + return ( + bool(self.access_token) + and (self.expires_at - time.time() > EXPIRY_BUFFER_SECS) + ) + + +def _token_dir() -> str: + return folder_paths.get_system_user_directory("hf_auth") + + +def _token_path() -> str: + return os.path.join(_token_dir(), TOKEN_FILENAME) + + +def load_token() -> Optional[Token]: + """Read the persisted token, returning ``None`` if absent or corrupt.""" + path = _token_path() + if not os.path.exists(path): + return None + try: + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + return Token(**data) + except (OSError, json.JSONDecodeError, TypeError) as e: + logging.warning("[hf_auth] could not load token at %s: %s", path, e) + return None + + +def save_token(token: Token) -> None: + """Atomically write the token with 0600 permissions.""" + path = _token_path() + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp = path + ".tmp" + fd = os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) + with os.fdopen(fd, "w", encoding="utf-8") as f: + json.dump(asdict(token), f) + os.replace(tmp, path) + try: + os.chmod(path, stat.S_IRUSR | stat.S_IWUSR) + except OSError as e: + # On Windows / weird filesystems chmod may be a no-op; not fatal. + logging.debug("[hf_auth] chmod 0600 on %s failed: %s", path, e) + + +def delete_token() -> None: + """Remove the persisted token; no-op if it doesn't exist.""" + path = _token_path() + try: + os.remove(path) + except FileNotFoundError: + pass + except OSError as e: + logging.warning("[hf_auth] could not remove token at %s: %s", path, e) diff --git a/app/model_downloader/hf_url.py b/app/model_downloader/hf_url.py new file mode 100644 index 000000000..c305d5ac0 --- /dev/null +++ b/app/model_downloader/hf_url.py @@ -0,0 +1,41 @@ +"""Parsers for the ``huggingface.co`` URL shape we accept in workflows. + +The download API accepts URLs of the form +``https://huggingface.co///resolve//``. +We need to recover ``/`` (the *repo_id*) from such URLs for +``huggingface_hub`` API calls (notably ``HfApi.auth_check``). +""" + +from __future__ import annotations + +from typing import Optional +from urllib.parse import urlparse + +_HF_HOST = "huggingface.co" + + +def is_hf_url(url: str) -> bool: + """Cheap host check — does this URL point at huggingface.co?""" + try: + return urlparse(url).hostname == _HF_HOST + except ValueError: + return False + + +def repo_id_from_url(url: str) -> Optional[str]: + """Extract ``/`` from an HF model file URL. + + Returns ``None`` if the URL isn't on huggingface.co or doesn't look + like a model-file URL. The expected shape is + ``///resolve//`` — anything else + (datasets, spaces, /tree/, /blob/, …) we treat as out of scope here. + """ + if not is_hf_url(url): + return None + parts = urlparse(url).path.lstrip("/").split("/") + if len(parts) < 4 or parts[2] != "resolve": + return None + org, repo = parts[0], parts[1] + if not org or not repo: + return None + return f"{org}/{repo}" diff --git a/app/model_downloader/http_client.py b/app/model_downloader/http_client.py new file mode 100644 index 000000000..4c2b81dc6 --- /dev/null +++ b/app/model_downloader/http_client.py @@ -0,0 +1,63 @@ +"""Lazy module-level aiohttp ClientSession. + +A single shared session means TLS handshakes are reused across HEAD probes +and the subsequent GETs to the same host (HuggingFace is the dominant +case), which is a noticeable speedup on cold connections. + +We deliberately don't close the session at process exit — aiohttp's +warning about unclosed sessions is benign at shutdown, and adding atexit +plumbing buys nothing because the OS reclaims the sockets anyway. The +session lifetime is the lifetime of the Python process. +""" + +from __future__ import annotations + +import asyncio +import ssl +from typing import Optional + +import aiohttp +import certifi + + +# Larger per-host pool than aiohttp's default (=100 total / =0 per host) +# so concurrent gated probes + a download to the same host don't queue. +_CONNECTOR_LIMIT_PER_HOST = 8 + +_session: Optional[aiohttp.ClientSession] = None +_lock = asyncio.Lock() + + +def ssl_context() -> ssl.SSLContext: + """TLS context pinned to certifi's CA bundle. + aiohttp's default context uses the OS trust store, which isn't wired up + on some Python installs (python.org macOS, slim containers) — there TLS + to huggingface.co fails with CERTIFICATE_VERIFY_FAILED. + """ + return ssl.create_default_context(cafile=certifi.where()) + + +async def get_session() -> aiohttp.ClientSession: + """Return the shared session, creating it on first call.""" + global _session + if _session is not None and not _session.closed: + return _session + async with _lock: + if _session is None or _session.closed: + connector = aiohttp.TCPConnector( + limit_per_host=_CONNECTOR_LIMIT_PER_HOST, + ssl=ssl_context(), + ) + _session = aiohttp.ClientSession(connector=connector) + return _session + + +def parse_content_length(value: Optional[str]) -> Optional[int]: + """Parse a byte-count header value, or None if absent/malformed/negative.""" + if not value: + return None + try: + n = int(value) + except ValueError: + return None + return n if n >= 0 else None diff --git a/app/model_downloader/paths.py b/app/model_downloader/paths.py new file mode 100644 index 000000000..002f7cb80 --- /dev/null +++ b/app/model_downloader/paths.py @@ -0,0 +1,111 @@ +"""Path resolution for model downloads. + +Model identifiers used across the download API are *relative destination +paths* of the form ``/`` (e.g. ``loras/my_lora.safetensors``). +This module turns one of those identifiers into an absolute on-disk path +under one of ComfyUI's registered model folders, while rejecting unknown +folders, path traversal, and other ill-formed inputs. +""" + +import os +import re +from typing import Optional, Tuple + +import folder_paths + + +# Constrain components so a model_id can never escape its target directory. +# - directory: a single path segment of safe chars +# - filename: a single path segment of safe chars, must end with a model extension +_SEGMENT_RE = re.compile(r"^[A-Za-z0-9._-]+$") + +# Destination filename must name a model file (same set as the URL allowlist), +# so a download can't land as e.g. ``foo.txt`` that ComfyUI won't recognise. +_MODEL_EXTENSIONS = (".safetensors", ".sft", ".ckpt", ".pth", ".pt") + +# Distinctive temp suffix so the startup orphan-sweep only removes files THIS +# subsystem created — never unrelated ``*.tmp`` files in the model dirs. +_TMP_SUFFIX = ".comfy-download.tmp" + + +class InvalidModelId(ValueError): + """Raised when a model_id is syntactically invalid or refers to an + unknown model folder.""" + + +def parse_model_id(model_id: str) -> Tuple[str, str]: + """Split ``/`` and validate both components. + + Returns ``(directory, filename)``. Raises ``InvalidModelId`` on + malformed input. Does NOT touch the filesystem. + """ + if not isinstance(model_id, str) or "/" not in model_id: + raise InvalidModelId(f"model_id must be '/', got {model_id!r}") + directory, _, filename = model_id.partition("/") + if "/" in filename or not directory or not filename: + raise InvalidModelId(f"model_id must be exactly one '/' separator, got {model_id!r}") + if not _SEGMENT_RE.match(directory): + raise InvalidModelId(f"invalid directory segment {directory!r}") + if not _SEGMENT_RE.match(filename): + raise InvalidModelId(f"invalid filename segment {filename!r}") + if not filename.endswith(_MODEL_EXTENSIONS): + raise InvalidModelId( + f"filename must end with a model extension {_MODEL_EXTENSIONS}, got {filename!r}" + ) + if directory not in folder_paths.folder_names_and_paths: + raise InvalidModelId(f"unknown model folder {directory!r}") + return directory, filename + + +def resolve_existing(model_id: str) -> Optional[str]: + """Return the absolute path of an installed model, or None if missing. + + Honours ``extra_model_paths.yaml`` transparently via + ``folder_paths.get_full_path``. + """ + directory, filename = parse_model_id(model_id) + return folder_paths.get_full_path(directory, filename) + + +def resolve_destination(model_id: str, epoch: int = 0) -> Tuple[str, str]: + """Return ``(final_path, tmp_path)`` for a download. + + Downloads land at the first registered path for the model's directory + (the "primary" location). The temp sibling is the write target, atomically + renamed onto ``final_path`` on success. + + ``tmp_path`` embeds the session ``epoch`` so a cancel+retry of the same + model never shares a temp path between the old (cancelling) worker and the + new attempt — otherwise the old worker's rollback could delete the new + worker's in-progress file. The distinctive suffix scopes the orphan sweep. + """ + directory, filename = parse_model_id(model_id) + roots = folder_paths.get_folder_paths(directory) + if not roots: + raise InvalidModelId(f"no on-disk path registered for folder {directory!r}") + root = roots[0] + final_path = os.path.join(root, filename) + tmp_path = f"{final_path}.{epoch}{_TMP_SUFFIX}" + return final_path, tmp_path + + +def iter_all_tmp_paths(): + """Yield this subsystem's temp files under every registered model folder. + + Matches only our distinctive ``_TMP_SUFFIX`` (not every ``*.tmp``) so the + startup orphan-sweep can't delete temp files created by other tools. + """ + seen_roots: set[str] = set() + for directory in folder_paths.folder_names_and_paths.keys(): + for root in folder_paths.get_folder_paths(directory): + if root in seen_roots or not os.path.isdir(root): + continue + seen_roots.add(root) + try: + for entry in os.scandir(root): + if entry.is_file() and entry.name.endswith(_TMP_SUFFIX): + yield entry.path + except OSError: + # Folder might be unreadable / missing on certain mounts — + # not fatal, just skip it. + continue diff --git a/tests-unit/app_test/hf_auth_test.py b/tests-unit/app_test/hf_auth_test.py new file mode 100644 index 000000000..a6784dbf6 --- /dev/null +++ b/tests-unit/app_test/hf_auth_test.py @@ -0,0 +1,708 @@ +"""Unit tests for the HuggingFace auth subsystem. + +Covers: + - token store: save/load roundtrip, chmod 0600, atomic write, delete + - eligibility under various CLI-arg combinations + - URL parsing (huggingface.co host detection + repo_id extraction) + - HF-aware gated_detection.probe_url (mocked auth_check) + - HF auth routes (token status, login start with eligibility gate, logout) + - PKCE primitives + authorize URL shape + +The OAuth callback server itself isn't exercised end-to-end here — that +requires a real HF server. We test the components (state checking, +URL building, code-exchange request shape) instead. +""" + +from __future__ import annotations + +import os +import stat +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from aiohttp import web + +from app.model_downloader.api.routes import register_routes +from app.model_downloader.hf_auth import oauth +from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE, HfAuthStore +from app.model_downloader.hf_auth.token_store import ( + EXPIRY_BUFFER_SECS, + Token, + delete_token, + load_token, + save_token, +) +from app.model_downloader.hf_url import is_hf_url, repo_id_from_url + + +# --------------------------------------------------------------------------- # +# Fixtures +# --------------------------------------------------------------------------- # + + +@pytest.fixture +def patched_user_dir(tmp_path): + """Redirect ``folder_paths.get_user_directory`` so the token file + lands in an isolated tmp_path instead of the real user dir.""" + user_dir = tmp_path / "user" + user_dir.mkdir() + with patch("folder_paths.get_user_directory", return_value=str(user_dir)): + yield user_dir + + +def _token_file_path(user_dir) -> str: + return os.path.join(user_dir, "__hf_auth", "hf_auth_token.json") + + +@pytest.fixture +def fresh_auth_store(): + """Wipe singleton state between tests: auth + probe caches.""" + from app.model_downloader import gated_detection + + HF_AUTH_STORE._token = None + HF_AUTH_STORE._loaded_from_disk = False + gated_detection.clear_caches_for_tests() + yield HF_AUTH_STORE + HF_AUTH_STORE._token = None + HF_AUTH_STORE._loaded_from_disk = False + gated_detection.clear_caches_for_tests() + + +@pytest.fixture +def app(patched_user_dir, fresh_auth_store): + app = web.Application() + register_routes(app) + return app + + +# --------------------------------------------------------------------------- # +# URL parsing +# --------------------------------------------------------------------------- # + + +def test_is_hf_url_recognises_huggingface_co(): + assert is_hf_url("https://huggingface.co/x/y/resolve/main/z.safetensors") + assert is_hf_url("https://huggingface.co/abc") + assert not is_hf_url("https://hf-mirror.com/x/y/resolve/main/z.safetensors") + assert not is_hf_url("https://civitai.com/x.safetensors") + + +def test_repo_id_from_url_extracts_org_and_repo(): + url = "https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-HDR/resolve/main/x.safetensors" + assert repo_id_from_url(url) == "Lightricks/LTX-2.3-22b-IC-LoRA-HDR" + + +def test_repo_id_from_url_handles_nested_path(): + url = "https://huggingface.co/Comfy-Org/ltx-2.3/resolve/main/split_files/loras/x.safetensors" + assert repo_id_from_url(url) == "Comfy-Org/ltx-2.3" + + +def test_repo_id_from_url_returns_none_for_non_hf(): + assert repo_id_from_url("https://civitai.com/x.safetensors") is None + + +def test_repo_id_from_url_returns_none_for_non_resolve_paths(): + assert repo_id_from_url("https://huggingface.co/org/repo/blob/main/x.safetensors") is None + assert repo_id_from_url("https://huggingface.co/org") is None + + +# --------------------------------------------------------------------------- # +# Token store +# --------------------------------------------------------------------------- # + + +def test_token_store_roundtrip(patched_user_dir): + tok = Token( + access_token="hf_abc", + refresh_token="rf_def", + expires_at=9999999999.0, + scope="openid profile", + ) + save_token(tok) + loaded = load_token() + assert loaded == tok + + +def test_token_store_writes_0600(patched_user_dir): + tok = Token(access_token="x", refresh_token=None, expires_at=0.0) + save_token(tok) + path = _token_file_path(patched_user_dir) + mode = stat.S_IMODE(os.stat(path).st_mode) + # On Windows we silently no-op chmod; allow either the intended + # mode or whatever umask the OS gave us. + if os.name == "posix": + assert mode == 0o600 + + +def test_token_store_delete_removes_file(patched_user_dir): + tok = Token(access_token="x", refresh_token=None, expires_at=0.0) + save_token(tok) + delete_token() + path = _token_file_path(patched_user_dir) + assert not os.path.exists(path) + # Idempotent: second delete is fine. + delete_token() + + +def test_token_store_load_returns_none_for_missing_file(patched_user_dir): + assert load_token() is None + + +def test_token_store_load_returns_none_for_corrupt_file(patched_user_dir): + path = _token_file_path(patched_user_dir) + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "w") as f: + f.write("not json {") + assert load_token() is None + + +def test_token_is_valid_uses_buffer(patched_user_dir): + import time + + fresh = Token(access_token="x", refresh_token=None, expires_at=time.time() + 3600) + nearly_expired = Token( + access_token="x", + refresh_token=None, + expires_at=time.time() + EXPIRY_BUFFER_SECS - 1, + ) + assert fresh.is_valid() + assert not nearly_expired.is_valid() + + +def test_token_is_valid_rejects_empty_access_token(): + import time + + tok = Token(access_token="", refresh_token=None, expires_at=time.time() + 3600) + assert not tok.is_valid() + + +def test_token_is_valid_rejects_at_exact_buffer_boundary(): + import time + + tok = Token( + access_token="x", + refresh_token=None, + expires_at=time.time() + EXPIRY_BUFFER_SECS, + ) + assert not tok.is_valid() + + +# --------------------------------------------------------------------------- # +# Auth store +# --------------------------------------------------------------------------- # + + +def test_auth_store_loads_lazily(patched_user_dir): + tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0) + save_token(tok) + store = HfAuthStore() + assert store.has_token() + assert store.get_token_sync() == tok + + +def test_auth_store_set_persists(patched_user_dir): + store = HfAuthStore() + tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0) + store.set_token(tok) + # Token is on disk now — a fresh store sees it. + assert HfAuthStore().get_token_sync() == tok + + +def test_auth_store_clear_removes_in_memory_and_on_disk(patched_user_dir): + store = HfAuthStore() + tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0) + store.set_token(tok) + store.clear() + assert not store.has_token() + assert HfAuthStore().get_token_sync() is None + + +def test_auth_store_has_token_true_when_expired_but_refreshable(patched_user_dir): + import time + + store = HfAuthStore() + expired = Token( + access_token="old", + refresh_token="rf", + expires_at=time.time() - 100, + ) + store.set_token(expired) + assert store.has_token() + assert not expired.is_valid() + + +def test_auth_store_get_token_sync_returns_expired_without_refresh(patched_user_dir): + import time + + store = HfAuthStore() + expired = Token( + access_token="old", + refresh_token=None, + expires_at=time.time() - 100, + ) + store.set_token(expired) + assert store.get_token_sync() == expired + + +@pytest.mark.asyncio +async def test_auth_store_get_valid_returns_none_when_expired_without_refresh( + patched_user_dir, +): + import time + + store = HfAuthStore() + expired = Token( + access_token="old", + refresh_token=None, + expires_at=time.time() - 100, + ) + store.set_token(expired) + with patch( + "app.model_downloader.hf_auth.oauth.refresh_access_token", + new=AsyncMock(), + ) as refresh_mock: + result = await store.get_valid_token() + assert result is None + refresh_mock.assert_not_called() + + +@pytest.mark.asyncio +async def test_auth_store_get_valid_returns_fresh_token(patched_user_dir): + store = HfAuthStore() + import time + + tok = Token(access_token="x", refresh_token=None, expires_at=time.time() + 3600) + store.set_token(tok) + fetched = await store.get_valid_token() + assert fetched == tok + + +@pytest.mark.asyncio +async def test_auth_store_get_valid_refresh_on_expired(patched_user_dir): + store = HfAuthStore() + import time + + expired = Token( + access_token="old", + refresh_token="rf", + expires_at=time.time() - 100, + ) + store.set_token(expired) + refreshed = Token( + access_token="new", + refresh_token="rf", + expires_at=time.time() + 3600, + ) + with patch( + "app.model_downloader.hf_auth.oauth.refresh_access_token", + new=AsyncMock(return_value=refreshed), + ): + result = await store.get_valid_token() + assert result == refreshed + + +@pytest.mark.asyncio +async def test_auth_store_get_valid_token_does_not_resurrect_after_logout( + patched_user_dir, +): + """A logout landing *during* an in-flight refresh must not be undone by + the refresh writing the token back (the resurrection race).""" + store = HfAuthStore() + import time + + expired = Token( + access_token="old", refresh_token="rf", expires_at=time.time() - 100 + ) + store.set_token(expired) + refreshed = Token( + access_token="new", refresh_token="rf", expires_at=time.time() + 3600 + ) + + async def fake_refresh(_refresh_token): + # Simulate the user clicking "Log out" while the refresh is in flight. + store.clear() + return refreshed + + with patch( + "app.model_downloader.hf_auth.oauth.refresh_access_token", + new=fake_refresh, + ): + result = await store.get_valid_token() + + # The refresh result is discarded — logout wins, in memory and on disk. + assert result is None + assert not store.has_token() + assert load_token() is None + + +@pytest.mark.asyncio +async def test_auth_store_get_valid_returns_none_on_refresh_failure(patched_user_dir): + store = HfAuthStore() + import time + + expired = Token( + access_token="old", + refresh_token="rf", + expires_at=time.time() - 100, + ) + store.set_token(expired) + with patch( + "app.model_downloader.hf_auth.oauth.refresh_access_token", + new=AsyncMock(side_effect=RuntimeError("HF down")), + ): + result = await store.get_valid_token() + assert result is None + + +# --------------------------------------------------------------------------- # +# Eligibility +# --------------------------------------------------------------------------- # + + +@pytest.mark.parametrize( + "listen,multi_user,expected", + [ + ("127.0.0.1", False, True), + ("127.0.0.1", True, False), # multi-user disables it + ("0.0.0.0", False, False), # bind-all is not loopback + ("0.0.0.0", True, False), + ("192.168.1.5", False, False), # LAN address + ("::1", False, True), # IPv6 loopback + ], +) +def test_eligibility(listen, multi_user, expected, monkeypatch): + from app.model_downloader.hf_auth import eligibility + from comfy.cli_args import args + + monkeypatch.setattr(args, "listen", listen) + monkeypatch.setattr(args, "multi_user", multi_user) + assert eligibility.is_hf_auth_eligible() is expected + + +# --------------------------------------------------------------------------- # +# gated_detection HF probe +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_probe_url_hf_public(fresh_auth_store): + """auth_check succeeds with no token → is_hf_downloadable = True.""" + from app.model_downloader.gated_detection import probe_url + + url = "https://huggingface.co/public/repo/resolve/main/x.safetensors" + with patch("app.model_downloader.gated_detection._auth_check_sync"), patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=1024), + ): + result = await probe_url(url) + assert result.is_hf_downloadable is True + assert result.file_size == 1024 + + +@pytest.mark.asyncio +async def test_probe_url_hf_gated_no_access(fresh_auth_store): + """auth_check raises GatedRepoError → is_hf_downloadable = False.""" + from huggingface_hub.errors import GatedRepoError + + from app.model_downloader.gated_detection import probe_url + + url = "https://huggingface.co/gated/repo/resolve/main/x.safetensors" + fake_response = MagicMock(status_code=403) + with patch( + "app.model_downloader.gated_detection._auth_check_sync", + side_effect=GatedRepoError("gated", response=fake_response), + ), patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=None), + ): + result = await probe_url(url) + assert result.is_hf_downloadable is False + + +@pytest.mark.asyncio +async def test_probe_url_non_hf_skips_auth_check(): + """Non-HF URLs never call auth_check; is_hf_downloadable stays None.""" + from app.model_downloader.gated_detection import probe_url + + url = "https://civitai.com/api/download/models/1.safetensors" + with patch( + "app.model_downloader.gated_detection._auth_check_sync", + ) as mocked, patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=2048), + ): + result = await probe_url(url) + assert result.is_hf_downloadable is None + assert result.file_size == 2048 + mocked.assert_not_called() + + +@pytest.mark.asyncio +async def test_is_gated_cached_across_calls(fresh_auth_store): + """Intrinsic ``is_gated`` should be determined exactly once per URL. + + Subsequent ``probe_url`` calls for the same URL must not re-issue + the null-token auth_check — that's the whole point of the cache.""" + from app.model_downloader.gated_detection import probe_url + + url = "https://huggingface.co/public/repo/resolve/main/x.safetensors" + with patch( + "app.model_downloader.gated_detection._auth_check_sync" + ) as mocked, patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=1024), + ): + await probe_url(url) + await probe_url(url) + await probe_url(url) + # Three probe_url calls × public-only-needs-1-auth_check = 1 call total. + assert mocked.call_count == 1 + + +@pytest.mark.asyncio +async def test_file_size_cached_across_calls(fresh_auth_store): + """Once a successful HEAD lands, subsequent calls don't re-HEAD.""" + from app.model_downloader.gated_detection import probe_url + + url = "https://huggingface.co/public/repo/resolve/main/x.safetensors" + with patch( + "app.model_downloader.gated_detection._auth_check_sync" + ), patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=2048), + ) as size_probe: + r1 = await probe_url(url) + r2 = await probe_url(url) + assert r1.file_size == 2048 + assert r2.file_size == 2048 + assert size_probe.call_count == 1 + + +@pytest.mark.asyncio +async def test_file_size_not_probed_for_gated_no_access(fresh_auth_store): + """When ``is_hf_downloadable`` is False we must NOT HEAD the URL — + otherwise a 401-due-to-gating would land as a cached ``None`` that + survives a later successful login.""" + from app.model_downloader.gated_detection import probe_url + from huggingface_hub.errors import GatedRepoError + from unittest.mock import MagicMock + + url = "https://huggingface.co/gated/repo/resolve/main/x.safetensors" + fake_resp = MagicMock(status_code=403) + with patch( + "app.model_downloader.gated_detection._auth_check_sync", + side_effect=GatedRepoError("gated", response=fake_resp), + ), patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=None), + ) as size_probe: + result = await probe_url(url) + assert result.is_hf_downloadable is False + assert result.file_size is None + assert size_probe.call_count == 0 + + +@pytest.mark.asyncio +async def test_probe_url_passes_token_when_available(fresh_auth_store, patched_user_dir): + """For a gated URL, auth_check runs twice: once with token=None to + determine the intrinsic ``is_gated`` flag (cached forever), and once + with the stored access_token to determine ``is_hf_downloadable`` for + the current user.""" + from app.model_downloader import gated_detection + from app.model_downloader.gated_detection import probe_url + from huggingface_hub.errors import GatedRepoError + from unittest.mock import MagicMock + + gated_detection.clear_caches_for_tests() + fresh_auth_store.set_token(Token( + access_token="hf_test_token", + refresh_token=None, + expires_at=9999999999.0, + )) + url = "https://huggingface.co/private/repo/resolve/main/x.safetensors" + + fake_resp = MagicMock(status_code=403) + + def fake_auth_check(repo_id, token): + # Null-token call → repo is gated. Subsequent call with the real + # token succeeds (user has access). + if token is None: + raise GatedRepoError("gated", response=fake_resp) + + with patch( + "app.model_downloader.gated_detection._auth_check_sync", + side_effect=fake_auth_check, + ) as mocked, patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=None), + ): + result = await probe_url(url) + + # is_hf_downloadable should be True (token-authed call succeeded). + assert result.is_hf_downloadable is True + # Two calls: (repo_id, None) then (repo_id, ). + assert mocked.call_count == 2 + assert mocked.call_args_list[0].args == ("private/repo", None) + assert mocked.call_args_list[1].args == ("private/repo", "hf_test_token") + + +# --------------------------------------------------------------------------- # +# OAuth primitives +# --------------------------------------------------------------------------- # + + +def test_make_pkce_returns_distinct_high_entropy_values(): + verifier1, challenge1, state1 = oauth._make_pkce() + verifier2, challenge2, state2 = oauth._make_pkce() + assert verifier1 != verifier2 + assert challenge1 != challenge2 + assert state1 != state2 + # Verifier should be at least 43 chars per PKCE spec. + assert len(verifier1) >= 43 + + +def test_build_authorize_url_includes_pkce_and_state(): + url = oauth._build_authorize_url("challenge123", "state456") + assert url.startswith(oauth.AUTHORIZE_URL) + assert "client_id=" + oauth.HF_CLIENT_ID in url + assert "code_challenge=challenge123" in url + assert "code_challenge_method=S256" in url + assert "state=state456" in url + assert "response_type=code" in url + + +# --------------------------------------------------------------------------- # +# Routes +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_hf_auth_token_status_empty(aiohttp_client, app): + """No token set → token_available=false, username=null.""" + client = await aiohttp_client(app) + resp = await client.get("/api/hf-auth-token-status") + assert resp.status == 200 + data = await resp.json() + assert data == {"token_available": False, "username": None} + + +@pytest.mark.asyncio +async def test_hf_auth_token_status_with_token( + aiohttp_client, app, fresh_auth_store, patched_user_dir +): + """Token present, whoami works → username is returned.""" + fresh_auth_store.set_token(Token( + access_token="x", refresh_token=None, expires_at=9999999999.0, + )) + with patch( + "app.model_downloader.api.routes._whoami_username", + return_value="alice", + ): + client = await aiohttp_client(app) + resp = await client.get("/api/hf-auth-token-status") + assert resp.status == 200 + assert (await resp.json()) == {"token_available": True, "username": "alice"} + + +@pytest.mark.asyncio +async def test_hf_auth_login_start_403_when_ineligible(aiohttp_client, app, monkeypatch): + """Not loopback / multi-user → 403.""" + monkeypatch.setattr( + "app.model_downloader.api.routes.is_hf_auth_eligible", + lambda: False, + ) + client = await aiohttp_client(app) + resp = await client.post("/api/hf-auth-login-start") + assert resp.status == 403 + assert (await resp.json())["error"]["code"] == "HF_AUTH_NOT_ELIGIBLE" + + +@pytest.mark.asyncio +async def test_hf_auth_login_start_returns_authorize_url(aiohttp_client, app, monkeypatch): + """Eligible + first attempt → 200 with authorize_url.""" + monkeypatch.setattr( + "app.model_downloader.api.routes.is_hf_auth_eligible", + lambda: True, + ) + monkeypatch.setattr( + "app.model_downloader.api.routes.start_login_flow", + AsyncMock(return_value="https://huggingface.co/oauth/authorize?fake=1"), + ) + client = await aiohttp_client(app) + resp = await client.post("/api/hf-auth-login-start") + assert resp.status == 200 + assert (await resp.json())["authorize_url"].startswith( + "https://huggingface.co/oauth/authorize" + ) + + +@pytest.mark.asyncio +async def test_hf_auth_login_start_409_when_in_progress(aiohttp_client, app, monkeypatch): + """Lock already held → 409.""" + from app.model_downloader.hf_auth.oauth import OAuthInProgressError + + monkeypatch.setattr( + "app.model_downloader.api.routes.is_hf_auth_eligible", + lambda: True, + ) + monkeypatch.setattr( + "app.model_downloader.api.routes.start_login_flow", + AsyncMock(side_effect=OAuthInProgressError()), + ) + client = await aiohttp_client(app) + resp = await client.post("/api/hf-auth-login-start") + assert resp.status == 409 + assert (await resp.json())["error"]["code"] == "HF_AUTH_IN_PROGRESS" + + +@pytest.mark.asyncio +async def test_hf_auth_login_start_503_when_callback_bind_fails( + aiohttp_client, app, monkeypatch +): + """Callback server failed to bind (e.g. port busy) → 503, not a dead URL.""" + from app.model_downloader.hf_auth.oauth import OAuthCallbackError + + monkeypatch.setattr( + "app.model_downloader.api.routes.is_hf_auth_eligible", + lambda: True, + ) + monkeypatch.setattr( + "app.model_downloader.api.routes.start_login_flow", + AsyncMock(side_effect=OAuthCallbackError("could not bind callback port")), + ) + client = await aiohttp_client(app) + resp = await client.post("/api/hf-auth-login-start") + assert resp.status == 503 + assert (await resp.json())["error"]["code"] == "HF_AUTH_START_FAILED" + + +@pytest.mark.asyncio +async def test_hf_auth_logout_clears_store( + aiohttp_client, app, fresh_auth_store, patched_user_dir +): + fresh_auth_store.set_token(Token( + access_token="x", refresh_token=None, expires_at=9999999999.0, + )) + client = await aiohttp_client(app) + resp = await client.post("/api/hf-auth-logout") + assert resp.status == 200 + assert (await resp.json()) == {"logged_out": True} + assert not fresh_auth_store.has_token() + + +@pytest.mark.asyncio +async def test_availability_includes_hf_auth_snapshot(aiohttp_client, app, monkeypatch): + """The availability response embeds {token_available, eligible}.""" + monkeypatch.setattr( + "app.model_downloader.api.routes.is_hf_auth_eligible", + lambda: True, + ) + client = await aiohttp_client(app) + resp = await client.post( + "/api/models-availability-status", + json={"models": {}}, + ) + assert resp.status == 200 + data = await resp.json() + assert "hf_auth" in data + assert data["hf_auth"] == {"token_available": False, "eligible": True} diff --git a/tests-unit/app_test/model_downloader_test.py b/tests-unit/app_test/model_downloader_test.py new file mode 100644 index 000000000..62a8d5190 --- /dev/null +++ b/tests-unit/app_test/model_downloader_test.py @@ -0,0 +1,514 @@ +"""Unit tests for the server-side model download subsystem. + +Covers the pieces that don't require talking to a real network: + + - path parsing & allowlist (pure functions) + - DownloadServer registry lifecycle (in-memory state) + - API routes via aiohttp_client + folder_paths/probe_url patches + +Streaming downloads themselves are exercised indirectly — the route-level +tests stub out the network probe so we can verify the gating logic in +``download_models`` without making real HTTP calls. +""" + +from __future__ import annotations + +import asyncio +from unittest.mock import patch, AsyncMock + +import pytest +from aiohttp import web + +from app.model_downloader.allowlist import is_url_allowed +from app.model_downloader.api.routes import register_routes +from app.model_downloader.download_server import DownloadServer +from app.model_downloader.gated_detection import MetadataProbeResult +from app.model_downloader.paths import ( + InvalidModelId, + parse_model_id, + resolve_destination, + resolve_existing, +) + +# --------------------------------------------------------------------------- # +# Fixtures +# --------------------------------------------------------------------------- # + + +@pytest.fixture +def model_root(tmp_path): + """A fake ``models/`` root with two registered folder types.""" + loras_dir = tmp_path / "loras" + checkpoints_dir = tmp_path / "checkpoints" + loras_dir.mkdir() + checkpoints_dir.mkdir() + return tmp_path, loras_dir, checkpoints_dir + + +@pytest.fixture +def patched_folder_paths(model_root): + """Point folder_paths at our fake roots for the duration of one test.""" + _root, loras_dir, checkpoints_dir = model_root + mapping = { + "loras": ([str(loras_dir)], {".safetensors"}), + "checkpoints": ([str(checkpoints_dir)], {".safetensors"}), + } + with patch( + "folder_paths.folder_names_and_paths", mapping + ), patch( + "folder_paths.get_folder_paths", + side_effect=lambda name: mapping.get(name, ([], set()))[0], + ): + yield mapping + + +@pytest.fixture +def fresh_download_server(): + """Reset the module-level singleton between tests so registry state + doesn't leak across tests sharing the singleton.""" + from app.model_downloader.download_server import DOWNLOAD_SERVER + + DOWNLOAD_SERVER.reset_for_tests() + yield DOWNLOAD_SERVER + DOWNLOAD_SERVER.reset_for_tests() + + +@pytest.fixture +def app(patched_folder_paths, fresh_download_server): + app = web.Application() + register_routes(app) + return app + + +# --------------------------------------------------------------------------- # +# Pure helpers: allowlist + path parsing +# --------------------------------------------------------------------------- # + + +def test_allowlist_accepts_hf_safetensors(): + assert is_url_allowed("https://huggingface.co/x/y/resolve/main/z.safetensors") + + +def test_allowlist_accepts_civitai_pth(): + assert is_url_allowed("https://civitai.com/api/download/models/123.pth") + + +def test_allowlist_rejects_unknown_host(): + assert not is_url_allowed("https://example.com/x.safetensors") + + +def test_allowlist_rejects_api_path_on_hf(): + # On an allowlisted host but not pointing at a model file. + assert not is_url_allowed("https://huggingface.co/api/models") + + +def test_allowlist_rejects_non_https_except_localhost(): + assert not is_url_allowed("http://huggingface.co/x/y.safetensors") + assert is_url_allowed("http://localhost:8000/x.safetensors") + + +def test_parse_model_id_valid(patched_folder_paths): + assert parse_model_id("loras/foo.safetensors") == ("loras", "foo.safetensors") + + +def test_parse_model_id_rejects_traversal(patched_folder_paths): + with pytest.raises(InvalidModelId): + parse_model_id("../etc/passwd") + + +def test_parse_model_id_rejects_unknown_folder(patched_folder_paths): + with pytest.raises(InvalidModelId): + parse_model_id("nope/x.safetensors") + + +def test_parse_model_id_rejects_double_slash(patched_folder_paths): + with pytest.raises(InvalidModelId): + parse_model_id("loras/sub/x.safetensors") + + +def test_resolve_existing_returns_path_when_present(model_root, patched_folder_paths): + _root, loras_dir, _ = model_root + target = loras_dir / "foo.safetensors" + target.write_bytes(b"x") + assert resolve_existing("loras/foo.safetensors") == str(target) + + +def test_resolve_existing_returns_none_when_absent(patched_folder_paths): + assert resolve_existing("loras/missing.safetensors") is None + + +def test_resolve_destination_returns_tmp_pair(model_root, patched_folder_paths): + _root, loras_dir, _ = model_root + final, tmp = resolve_destination("loras/foo.safetensors", epoch=7) + assert final == str(loras_dir / "foo.safetensors") + # Temp path embeds the session epoch (so cancel+retry can't collide on it) + # and uses the subsystem-specific suffix the startup sweep matches. + assert tmp == f"{final}.7.comfy-download.tmp" + + +# --------------------------------------------------------------------------- # +# DownloadServer registry: lifecycle, races, cancellation epoch semantics +# --------------------------------------------------------------------------- # + + +def test_register_is_exclusive(): + server = DownloadServer() + s1 = server.try_register("loras/x.safetensors", "https://huggingface.co/a") + s2 = server.try_register("loras/x.safetensors", "https://huggingface.co/b") + assert s1 is not None + assert s2 is None + assert server.is_downloading("loras/x.safetensors") + + +def test_cancel_removes_session(): + server = DownloadServer() + server.try_register("loras/x.safetensors", "https://huggingface.co/a") + assert server.cancel("loras/x.safetensors") is True + assert not server.is_downloading("loras/x.safetensors") + + +def test_cancel_returns_false_when_absent(): + server = DownloadServer() + assert server.cancel("loras/never.safetensors") is False + + +def test_finish_only_clears_matching_epoch(): + """If a session is cancelled and a new one for the same id is + registered, ``finish`` from the original worker must not evict the + newer session.""" + server = DownloadServer() + s_old = server.try_register("loras/x.safetensors", "u1") + server.cancel("loras/x.safetensors") + s_new = server.try_register("loras/x.safetensors", "u2") + assert s_new is not None and s_new.epoch != s_old.epoch + # Old worker's late finish() is a no-op: + server.finish(s_old) + assert server.is_downloading("loras/x.safetensors") + server.finish(s_new) + assert not server.is_downloading("loras/x.safetensors") + + +def test_is_active_follows_cancellation(): + server = DownloadServer() + s = server.try_register("loras/x.safetensors", "u") + assert server.is_active(s) + server.cancel("loras/x.safetensors") + assert not server.is_active(s) + + +def test_update_progress_tracks_fraction(): + server = DownloadServer() + s = server.try_register("loras/x.safetensors", "u") + server.update_progress(s, 50, 100) + snap = server.snapshot()["loras/x.safetensors"] + assert snap.bytes_downloaded == 50 + assert snap.total_bytes == 100 + assert snap.progress == 0.5 + + +def test_update_progress_with_unknown_total_keeps_progress_none(): + server = DownloadServer() + s = server.try_register("loras/x.safetensors", "u") + server.update_progress(s, 50, None) + assert server.snapshot()["loras/x.safetensors"].progress is None + + +def test_cleanup_orphan_tmp_files(model_root): + """Orphan temp left by a crashed download must be swept on first use, + while unrelated *.tmp files in the model dir are left untouched.""" + _root, loras_dir, _ = model_root + orphan = loras_dir / "stale.safetensors.3.comfy-download.tmp" + orphan.write_bytes(b"partial") + unrelated = loras_dir / "someothertool.tmp" + unrelated.write_bytes(b"not ours") + mapping = {"loras": ([str(loras_dir)], {".safetensors"})} + with patch("folder_paths.folder_names_and_paths", mapping), patch( + "folder_paths.get_folder_paths", + side_effect=lambda name: mapping.get(name, ([], set()))[0], + ): + server = DownloadServer() + assert orphan.exists(), "sweep must not run at construction time" + server.sweep_orphan_tmp_files() + assert not orphan.exists() + assert unrelated.exists(), "unrelated .tmp must not be swept" + # Idempotent — a second call is a cheap no-op. + server.sweep_orphan_tmp_files() + + +# --------------------------------------------------------------------------- # +# Route: POST /api/models-availability-status +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_availability_partitions_correctly( + aiohttp_client, app, model_root, fresh_download_server +): + _root, loras_dir, _ = model_root + (loras_dir / "present.safetensors").write_bytes(b"x") + fresh_download_server.try_register( + "loras/inflight.safetensors", "http://localhost:8000/x.safetensors" + ) + client = await aiohttp_client(app) + + # Stub probes — we're testing state assignment, not network calls. + with patch( + "app.model_downloader.api.routes.probe_url", + new=AsyncMock(return_value=MetadataProbeResult( + file_size=None, is_hf_downloadable=None, + )), + ): + body = { + "models": { + "loras/present.safetensors": "http://localhost:8000/p.safetensors", + "loras/missing.safetensors": "http://localhost:8000/m.safetensors", + "loras/inflight.safetensors": "http://localhost:8000/x.safetensors", + } + } + resp = await client.post("/api/models-availability-status", json=body) + assert resp.status == 200 + data = await resp.json() + models = data["models"] + assert models["loras/present.safetensors"]["state"] == "available" + assert models["loras/missing.safetensors"]["state"] == "missing" + assert models["loras/inflight.safetensors"]["state"] == "downloading" + assert "hf_auth" in data + + +@pytest.mark.asyncio +async def test_availability_invalid_id_classified_as_missing(aiohttp_client, app): + client = await aiohttp_client(app) + with patch( + "app.model_downloader.api.routes.probe_url", + new=AsyncMock(return_value=MetadataProbeResult( + file_size=None, is_hf_downloadable=None, + )), + ): + resp = await client.post( + "/api/models-availability-status", + json={"models": {"../etc/passwd": "http://localhost:8000/x.safetensors"}}, + ) + assert resp.status == 200 + data = await resp.json() + assert data["models"]["../etc/passwd"]["state"] == "missing" + + +# --------------------------------------------------------------------------- # +# Route: POST /api/download-models — precondition gating +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_download_rejects_url_not_in_allowlist(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={"models": {"loras/x.safetensors": "https://evil.com/x.safetensors"}}, + ) + assert resp.status == 400 + err = (await resp.json())["error"] + assert err["code"] == "URL_NOT_ALLOWED" + + +@pytest.mark.asyncio +async def test_download_rejects_already_available( + aiohttp_client, app, model_root +): + _root, loras_dir, _ = model_root + (loras_dir / "x.safetensors").write_bytes(b"x") + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={"models": { + "loras/x.safetensors": "https://huggingface.co/a/b/resolve/main/x.safetensors" + }}, + ) + assert resp.status == 409 + assert (await resp.json())["error"]["code"] == "ALREADY_AVAILABLE" + + +@pytest.mark.asyncio +async def test_download_rejects_already_downloading( + aiohttp_client, app, fresh_download_server +): + fresh_download_server.try_register( + "loras/x.safetensors", "https://huggingface.co/u.safetensors" + ) + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={"models": { + "loras/x.safetensors": "https://huggingface.co/a/b/resolve/main/x.safetensors" + }}, + ) + assert resp.status == 409 + assert (await resp.json())["error"]["code"] == "ALREADY_DOWNLOADING" + + +@pytest.mark.asyncio +async def test_download_rejects_gated_model(aiohttp_client, app): + client = await aiohttp_client(app) + with patch( + "app.model_downloader.api.routes.probe_url", + new=AsyncMock(return_value=MetadataProbeResult(file_size=None, is_hf_downloadable=False)), + ): + resp = await client.post( + "/api/download-models", + json={"models": { + "loras/x.safetensors": "https://huggingface.co/g/r/resolve/main/x.safetensors" + }}, + ) + assert resp.status == 400 + assert (await resp.json())["error"]["code"] == "MODEL_NOT_DOWNLOADABLE" + + +@pytest.mark.asyncio +async def test_download_rejects_invalid_model_id(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={"models": {"../etc/passwd": "https://huggingface.co/x.safetensors"}}, + ) + assert resp.status == 400 + assert (await resp.json())["error"]["code"] == "INVALID_MODEL_ID" + + +@pytest.mark.asyncio +async def test_download_atomic_failure_does_not_register_partial( + aiohttp_client, app, model_root, fresh_download_server +): + """If one model in a batch fails, none get registered.""" + _root, loras_dir, _ = model_root + (loras_dir / "already.safetensors").write_bytes(b"x") + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={ + "models": { + "loras/already.safetensors": + "https://huggingface.co/a/b/resolve/main/already.safetensors", + "loras/new.safetensors": + "https://huggingface.co/a/b/resolve/main/new.safetensors", + } + }, + ) + assert resp.status == 409 + # The "new" model should not have been registered as part of the + # failed batch. + assert not fresh_download_server.is_downloading("loras/new.safetensors") + + +@pytest.mark.asyncio +async def test_download_schedules_when_all_preconditions_pass( + aiohttp_client, app, fresh_download_server +): + """Verify the precondition pass, registration pass, and async + scheduling all wire up correctly. We patch the streamer to avoid + real HTTP while still letting the route execute end-to-end.""" + started = asyncio.Event() + finish_signal = asyncio.Event() + + async def fake_stream(session): + started.set() + await finish_signal.wait() + from app.model_downloader.download_server import DOWNLOAD_SERVER + DOWNLOAD_SERVER.finish(session) + return "/dev/null" + + with patch( + "app.model_downloader.api.routes.probe_url", + new=AsyncMock(return_value=MetadataProbeResult(file_size=42, is_hf_downloadable=True)), + ), patch( + "app.model_downloader.downloader.stream_to_disk", new=fake_stream + ): + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={"models": { + "loras/new.safetensors": + "https://huggingface.co/a/b/resolve/main/new.safetensors" + }}, + ) + assert resp.status == 202 + body = await resp.json() + assert body["accepted"] is True + assert body["scheduled"] == ["loras/new.safetensors"] + # Wait for the worker to actually start. + await asyncio.wait_for(started.wait(), timeout=2.0) + assert fresh_download_server.is_downloading("loras/new.safetensors") + finish_signal.set() + + +# --------------------------------------------------------------------------- # +# Route: POST /api/cancel-model-download-session +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_cancel_removes_active_session( + aiohttp_client, app, fresh_download_server +): + fresh_download_server.try_register( + "loras/x.safetensors", "https://huggingface.co/u.safetensors" + ) + client = await aiohttp_client(app) + resp = await client.post( + "/api/cancel-model-download-session", + json={"model_id": "loras/x.safetensors"}, + ) + assert resp.status == 200 + assert (await resp.json())["cancelled"] is True + assert not fresh_download_server.is_downloading("loras/x.safetensors") + + +@pytest.mark.asyncio +async def test_cancel_returns_404_when_none(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.post( + "/api/cancel-model-download-session", + json={"model_id": "loras/nothing.safetensors"}, + ) + assert resp.status == 404 + assert (await resp.json())["error"]["code"] == "NOT_DOWNLOADING" + + +# --------------------------------------------------------------------------- # +# Unified availability response embeds metadata per id +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_availability_embeds_metadata(aiohttp_client, app): + """``file_size`` + ``is_hf_downloadable`` come back on the same + request as the state — no separate metadata endpoint.""" + results = { + "https://huggingface.co/a/b/resolve/main/free.safetensors": + MetadataProbeResult(file_size=1024, is_hf_downloadable=True), + "https://huggingface.co/g/r/resolve/main/gated.safetensors": + MetadataProbeResult(file_size=None, is_hf_downloadable=False), + } + + async def fake_probe(url): + return results[url] + + with patch( + "app.model_downloader.api.routes.probe_url", new=fake_probe + ): + client = await aiohttp_client(app) + resp = await client.post( + "/api/models-availability-status", + json={ + "models": { + "loras/free.safetensors": + "https://huggingface.co/a/b/resolve/main/free.safetensors", + "loras/gated.safetensors": + "https://huggingface.co/g/r/resolve/main/gated.safetensors", + } + }, + ) + assert resp.status == 200 + models = (await resp.json())["models"] + assert models["loras/free.safetensors"]["file_size"] == 1024 + assert models["loras/free.safetensors"]["is_hf_downloadable"] is True + assert models["loras/gated.safetensors"]["file_size"] is None + assert models["loras/gated.safetensors"]["is_hf_downloadable"] is False