feat(model_downloader): add server-side model downloads with gated-repo support

Lets ComfyUI fetch the models a workflow needs directly on the server,
so users no longer have to locate each file and drop it into the correct
folder by hand.

Crucially it supports gated HuggingFace repositories: the user logs in
once via HuggingFace, after which the server can download models that
require license acceptance or authentication — previously a manual,
error-prone step. The frontend can surface per-model availability and
download progress through the accompanying API.
This commit is contained in:
Alex 2026-06-24 09:06:22 +03:00 committed by DoronGenzelHass
parent 418d272cfa
commit 351119eb05
16 changed files with 3180 additions and 0 deletions

View File

@ -0,0 +1,51 @@
"""URL allowlist for server-side model fetches.
Mirrors the frontend's ``isModelDownloadable`` allowlist so the two flows
agree on which URLs are eligible for download. Server-side allowlisting is
the primary SSRF defense for this subsystem — workflow JSON is untrusted
input (anyone can hand-craft one), so we never let the server fetch URLs
outside this list.
"""
from urllib.parse import urlparse
# Frontend parity: ``missingModelDownload-*.js`` exports the same triple
# (Civitai / HuggingFace / localhost). Keyed by exact hostname → allowed
# schemes, and matched against the *parsed* host (not a raw string prefix),
# so URL-userinfo tricks can't slip past — see ``is_url_allowed``.
_ALLOWED_HOSTS = {
"huggingface.co": {"https"},
"civitai.com": {"https"},
"localhost": {"http"},
"127.0.0.1": {"http"},
}
# Frontend parity: same set as ``a = [...]`` in the bundle.
_ALLOWED_MODEL_EXTENSIONS = (
".safetensors",
".sft",
".ckpt",
".pth",
".pt",
)
def is_url_allowed(url: str) -> bool:
    """Decide whether ``url`` may be fetched by the server.

    A URL passes only when all of the following hold:

    * it is a non-empty string that ``urlparse`` accepts;
    * its parsed hostname is a key of ``_ALLOWED_HOSTS`` and its scheme is
      one of the schemes listed for that host;
    * its path ends in one of ``_ALLOWED_MODEL_EXTENSIONS``.

    The host comparison uses ``parsed.hostname`` rather than a string prefix,
    which defeats userinfo tricks such as
    ``http://127.0.0.1:80@169.254.169.254/x.safetensors`` (real host:
    ``169.254.169.254``). The extension check rejects non-model URLs on
    allowed hosts (e.g. ``huggingface.co/api/...``).
    """
    if not (isinstance(url, str) and url):
        return False

    try:
        parts = urlparse(url)
    except ValueError:
        return False

    hostname = parts.hostname
    if hostname is None:
        return False

    permitted_schemes = _ALLOWED_HOSTS.get(hostname, ())
    if parts.scheme not in permitted_schemes:
        return False

    # ``str.endswith`` accepts a tuple of suffixes, so one call covers them all.
    return parts.path.endswith(_ALLOWED_MODEL_EXTENSIONS)

View File

@ -0,0 +1,359 @@
"""Aiohttp routes for the server-side model download subsystem.
Endpoint surface (all under ``/api/``, all kebab-case):
- ``POST /api/models-availability-status`` bulk status + metadata query.
- ``POST /api/download-models`` start a batch of downloads.
- ``POST /api/cancel-model-download-session`` cancel a single in-flight one.
- ``GET /api/hf-auth-token-status`` current HF login state.
- ``POST /api/hf-auth-login-start`` begin the HF OAuth flow.
- ``POST /api/hf-auth-logout`` drop the stored HF token.
The contract is intentionally narrow: only model_ids of the form
``<directory>/<filename>`` (validated via ``app.model_downloader.paths``)
are accepted, and only URLs on the same allowlist the frontend already
uses (HuggingFace, Civitai, localhost) can be fetched. Both are required
to keep the server out of the SSRF business for this feature.
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any, Literal, Optional
from aiohttp import web
from pydantic import BaseModel, ValidationError
from app.model_downloader.allowlist import is_url_allowed
from app.model_downloader.download_server import (
DOWNLOAD_SERVER,
DownloadSession,
)
from app.model_downloader.downloader import schedule_batch
from app.model_downloader.gated_detection import probe_url
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
from app.model_downloader.hf_auth.eligibility import is_hf_auth_eligible
from app.model_downloader.hf_auth.oauth import (
OAuthCallbackError,
OAuthInProgressError,
start_login_flow,
)
from app.model_downloader.paths import (
InvalidModelId,
parse_model_id,
resolve_existing,
)
from app.model_downloader.api import schemas_in, schemas_out
# Route table collecting every handler in this module; handlers attach
# themselves through the ``@ROUTES.get`` / ``@ROUTES.post`` decorators and the
# table is installed on the app by ``register_routes``.
ROUTES = web.RouteTableDef()
def register_routes(app: web.Application) -> None:
    """Attach every model-downloader endpoint to ``app``.

    Meant to be invoked a single time by ``server.py`` while the
    ``PromptServer`` is starting up.
    """
    app.add_routes(ROUTES)
# ----- response helpers (same envelope as app/assets/api/routes.py) -----

# Closed set of machine-readable error codes that ``_error`` may place in the
# ``error.code`` field of a response.
ErrorCode = Literal[
    "INVALID_JSON",
    "INVALID_BODY",
    "EMPTY_REQUEST",
    "INVALID_MODEL_ID",
    "URL_NOT_ALLOWED",
    "ALREADY_AVAILABLE",
    "ALREADY_DOWNLOADING",
    "MODEL_NOT_DOWNLOADABLE",
    "NOT_DOWNLOADING",
    "HF_AUTH_NOT_ELIGIBLE",
    "HF_AUTH_IN_PROGRESS",
    "HF_AUTH_START_FAILED",
]
def _error(status: int, code: ErrorCode, message: str, details: dict | None = None) -> web.Response:
    """Build the JSON error envelope used by every endpoint in this module.

    The body has the shape ``{"error": {"code", "message", "details"}}``;
    a missing or empty ``details`` is sent as ``{}``.
    """
    envelope = {
        "error": {
            "code": code,
            "message": message,
            "details": details or {},
        },
    }
    return web.json_response(envelope, status=status)
def _validation_error(code: ErrorCode, ve: ValidationError) -> web.Response:
    """Turn a pydantic ``ValidationError`` into a 400 error response.

    The individual validation issues are passed through under
    ``details.errors``.
    """
    issues = json.loads(ve.json())
    return _error(400, code, "Validation failed.", {"errors": issues})
def _ok(payload: BaseModel, status: int = 200) -> web.Response:
    """Serialize a response model to JSON; ``None`` fields are kept as null."""
    body = payload.model_dump(mode="json", exclude_none=False)
    return web.json_response(body, status=status)
async def _parse_body(request: web.Request, model: type[BaseModel]) -> Any:
    """Parse the JSON request body into ``model``.

    Returns the validated model instance on success. On failure the 400
    error response is *returned* (not raised), so callers must check
    ``isinstance(result, web.Response)`` before using the value.
    """
    try:
        raw = await request.json()
    except ValueError:
        # ``json.JSONDecodeError`` and ``UnicodeDecodeError`` (body bytes that
        # do not decode in the declared charset) both derive from
        # ``ValueError``; either one means the client did not send usable
        # JSON, which is a 400 rather than an unhandled 500.
        return _error(400, "INVALID_JSON", "Request body must be valid JSON.")
    try:
        return model.model_validate(raw)
    except ValidationError as ve:
        return _validation_error("INVALID_BODY", ve)
# ----- 1. availability status (unified: state + metadata per id) -----
def _status_entry(model_id: str, probe: Any) -> schemas_out.ModelStatusEntry:
    """Build the status row for one model from its probe, the registry and disk."""
    probe_fields = {
        "file_size": probe.file_size,
        "is_hf_downloadable": probe.is_hf_downloadable,
    }

    try:
        parse_model_id(model_id)
    except InvalidModelId:
        # Ill-formed identifier: report as missing without 400-ing the
        # whole batch — the workflow author probably typo'd.
        return schemas_out.ModelStatusEntry(state="missing", **probe_fields)

    in_flight = DOWNLOAD_SERVER.get(model_id)
    if in_flight is not None:
        return schemas_out.ModelStatusEntry(
            state="downloading",
            progress=schemas_out.DownloadProgress(
                bytes_downloaded=in_flight.bytes_downloaded,
                total_bytes=in_flight.total_bytes,
                progress=in_flight.progress,
            ),
            **probe_fields,
        )

    on_disk = resolve_existing(model_id) is not None
    return schemas_out.ModelStatusEntry(
        state="available" if on_disk else "missing",
        **probe_fields,
    )


@ROUTES.post("/api/models-availability-status")
async def models_availability_status(request: web.Request) -> web.Response:
    """Return per-id ``{state, progress, file_size, is_hf_downloadable}``.

    State (``available`` / ``missing`` / ``downloading``) is cheap to
    recompute per call. ``file_size`` and ``is_gated`` are cached
    server-side per URL. ``is_hf_downloadable`` is recomputed every call
    from the current token state — that's what makes login + license
    acceptance show up in the UI within one poll cycle without any
    frontend cache plumbing.
    """
    body = await _parse_body(request, schemas_in.AvailabilityStatusRequest)
    if isinstance(body, web.Response):
        return body

    pairs = list(body.models.items())
    # One probe per requested URL, all in flight together; each probe is
    # internally cached per URL.
    probes = await asyncio.gather(*(probe_url(url) for _, url in pairs))

    entries: dict[str, schemas_out.ModelStatusEntry] = {
        model_id: _status_entry(model_id, probe)
        for (model_id, _url), probe in zip(pairs, probes)
    }
    auth_snapshot = schemas_out.HfAuthStatus(
        token_available=HF_AUTH_STORE.has_token(),
        eligible=is_hf_auth_eligible(),
    )
    return _ok(schemas_out.AvailabilityStatusResponse(
        models=entries,
        hf_auth=auth_snapshot,
    ))
# ----- 2. start downloads -----
def _precheck_download(model_id: str, url: str) -> Optional[web.Response]:
    """Run the offline preconditions for one requested download.

    Returns the error response for the first failed check, or ``None`` when
    the model id is well-formed, the URL is allowlisted, the file is not on
    disk yet and no download for it is in flight.
    """
    try:
        parse_model_id(model_id)
    except InvalidModelId as e:
        return _error(400, "INVALID_MODEL_ID", str(e),
                      {"model_id": model_id})
    if not is_url_allowed(url):
        return _error(
            400, "URL_NOT_ALLOWED",
            "Server-side downloads only accept HuggingFace, Civitai, "
            "or localhost URLs ending in a known model extension.",
            {"model_id": model_id, "url": url},
        )
    if resolve_existing(model_id) is not None:
        return _error(409, "ALREADY_AVAILABLE",
                      f"Model already exists on disk: {model_id}",
                      {"model_id": model_id})
    if DOWNLOAD_SERVER.is_downloading(model_id):
        return _error(409, "ALREADY_DOWNLOADING",
                      f"A download for {model_id} is already in progress.",
                      {"model_id": model_id})
    return None


@ROUTES.post("/api/download-models")
async def download_models(request: web.Request) -> web.Response:
    """Validate a batch of ``{model_id: url}`` pairs and schedule the downloads.

    The request is all-or-nothing: if any model fails a precondition
    (invalid id, URL not allowlisted, already on disk, already downloading,
    or gated without access) an error is returned and nothing is registered.
    On success every model is registered, the batch is scheduled, and a 202
    listing the scheduled ids is returned without waiting for the downloads.
    """
    parsed = await _parse_body(request, schemas_in.DownloadModelsRequest)
    if isinstance(parsed, web.Response):
        return parsed
    if not parsed.models:
        return _error(400, "EMPTY_REQUEST", "No models supplied.")

    # ----- precondition pass: validate everything BEFORE registering anything -----
    requested = list(parsed.models.items())
    for model_id, url in requested:
        failure = _precheck_download(model_id, url)
        if failure is not None:
            return failure

    # Reachability check last — it's the only one that talks to the
    # network. Probes run concurrently. For HF URLs ``is_hf_downloadable``
    # reflects current token access; for non-HF URLs it's None, which is
    # treated as "no info, proceed".
    probes = await asyncio.gather(*(probe_url(url) for _, url in requested))
    for (model_id, url), probe in zip(requested, probes):
        if probe.is_hf_downloadable is False:
            return _error(
                400, "MODEL_NOT_DOWNLOADABLE",
                f"Model {model_id} is gated on HuggingFace and the current "
                f"server token (if any) does not grant access.",
                {"model_id": model_id, "url": url},
            )

    # ----- registration pass: try_register is atomic per model_id -----
    # Another request may have raced past the pre-check between the loop
    # above and here; try_register returning None reveals that.
    sessions: list[DownloadSession] = []
    for model_id, url in requested:
        session = DOWNLOAD_SERVER.try_register(model_id, url)
        if session is None:
            # Lost the race: undo the registrations made so far.
            for earlier in sessions:
                DOWNLOAD_SERVER.cancel(earlier.model_id)
            return _error(409, "ALREADY_DOWNLOADING",
                          f"A download for {model_id} is already in progress (race).",
                          {"model_id": model_id})
        sessions.append(session)

    DOWNLOAD_SERVER.sweep_orphan_tmp_files()
    schedule_batch(sessions)

    scheduled_ids = [s.model_id for s in sessions]
    logging.info(
        "[model_downloader] scheduled %d downloads: %s",
        len(sessions), scheduled_ids,
    )
    return _ok(schemas_out.DownloadModelsResponse(
        accepted=True,
        scheduled=scheduled_ids,
    ), status=202)
# ----- 3. cancel a session -----
@ROUTES.post("/api/cancel-model-download-session")
async def cancel_model_download_session(request: web.Request) -> web.Response:
    """Cancel the in-flight download registered for one ``model_id``.

    Responds 400 for an ill-formed id, 404 when nothing is downloading under
    that id, and ``{"cancelled": true}`` otherwise.
    """
    parsed = await _parse_body(request, schemas_in.CancelDownloadSessionRequest)
    if isinstance(parsed, web.Response):
        return parsed

    try:
        parse_model_id(parsed.model_id)
    except InvalidModelId as e:
        return _error(400, "INVALID_MODEL_ID", str(e), {"model_id": parsed.model_id})

    if DOWNLOAD_SERVER.cancel(parsed.model_id):
        return _ok(schemas_out.CancelDownloadSessionResponse(cancelled=True))

    return _error(404, "NOT_DOWNLOADING",
                  f"No active download for {parsed.model_id}.",
                  {"model_id": parsed.model_id})
# ----- 4. HuggingFace OAuth status / login start / logout -----
@ROUTES.get("/api/hf-auth-token-status")
async def hf_auth_token_status(request: web.Request) -> web.Response:
    """Report whether the server holds an HF token, plus its username.

    Used by the settings UI and (out-of-band) by the frontend on login
    completion. ``token_available`` is true even if the cached access_token
    is expired — as long as a refresh_token exists, the user is "logged in"
    from their perspective. ``username`` stays ``None`` when no valid token
    can be obtained or the whoami lookup fails.
    """
    token_present = HF_AUTH_STORE.has_token()
    username: Optional[str] = None

    tok = await HF_AUTH_STORE.get_valid_token() if token_present else None
    if tok is not None:
        # huggingface_hub's whoami is synchronous and blocks on a network
        # call, so it runs in a worker thread. Failure is non-fatal: the
        # status is still reported, just without a username.
        try:
            username = await asyncio.to_thread(_whoami_username, tok.access_token)
        except Exception as e:
            logging.debug("[hf_auth] whoami failed: %s", e)

    status = schemas_out.HfAuthTokenStatusResponse(
        token_available=token_present,
        username=username,
    )
    return _ok(status)
def _whoami_username(token: str) -> Optional[str]:
    """Blocking helper: look up the user name HF associates with ``token``.

    Returns the ``name`` field of the whoami payload, falling back to
    ``fullname``; ``None`` when the payload is not a dict or has neither.
    """
    from huggingface_hub import HfApi

    info = HfApi().whoami(token=token)
    if not isinstance(info, dict):
        return None
    return info.get("name") or info.get("fullname")
@ROUTES.post("/api/hf-auth-login-start")
async def hf_auth_login_start(request: web.Request) -> web.Response:
    """Start one OAuth attempt and hand back the URL the user must open.

    Responds 403 when this deployment is not eligible for interactive login
    (the option is not surfaced on multi-tenant / public-IP installs), 409
    while another attempt is pending, and 503 when the flow cannot be
    started.
    """
    if not is_hf_auth_eligible():
        return _error(
            403, "HF_AUTH_NOT_ELIGIBLE",
            "This server is not eligible for interactive HuggingFace login. "
            "It must be bound to a loopback address and not running in "
            "--multi-user mode.",
        )

    try:
        authorize_url = await start_login_flow()
    except OAuthInProgressError:
        return _error(
            409, "HF_AUTH_IN_PROGRESS",
            "Another HuggingFace login attempt is in progress. Try again "
            "after it completes or times out.",
        )
    except OAuthCallbackError as e:
        return _error(
            503, "HF_AUTH_START_FAILED",
            f"Could not start the HuggingFace login flow: {e}",
        )
    else:
        return _ok(schemas_out.HfAuthLoginStartResponse(authorize_url=authorize_url))
@ROUTES.post("/api/hf-auth-logout")
async def hf_auth_logout(request: web.Request) -> web.Response:
    """Forget the stored HF token (in memory and on disk) and confirm it."""
    HF_AUTH_STORE.clear()
    confirmation = schemas_out.HfAuthLogoutResponse(logged_out=True)
    return _ok(confirmation)

View File

@ -0,0 +1,41 @@
"""Request schemas for the model-downloader API.
Each endpoint accepts a small JSON body. Pydantic enforces the shape at
the boundary; route handlers operate only on validated values past that.
"""
from __future__ import annotations
from pydantic import BaseModel, Field
class AvailabilityStatusRequest(BaseModel):
    """Body of ``POST /api/models-availability-status``.

    The frontend sends this on every poll. ``models`` maps each ``model_id``
    to the URL declared for it in the workflow JSON
    (``properties.models[i].url``), which lets the server compute the per-id
    metadata (``file_size`` and ``is_hf_downloadable``) in the same request.
    """

    # {model_id: url}; defaults to an empty mapping when omitted.
    models: dict[str, str] = Field(default_factory=dict)
class DownloadModelsRequest(BaseModel):
    """Body of ``POST /api/download-models``.

    Same shape as the availability request: the URL to fetch for each
    ``model_id``. The endpoint returns as soon as the batch has been
    validated and scheduled.
    """

    # {model_id: url}; defaults to an empty mapping when omitted.
    models: dict[str, str] = Field(default_factory=dict)
class CancelDownloadSessionRequest(BaseModel):
    """Body of ``POST /api/cancel-model-download-session``."""

    # Identifier of the download to cancel (required).
    model_id: str
__all__ = [
"AvailabilityStatusRequest",
"DownloadModelsRequest",
"CancelDownloadSessionRequest",
]

View File

@ -0,0 +1,81 @@
"""Response schemas for the model-downloader API."""
from __future__ import annotations
from typing import Literal, Optional
from pydantic import BaseModel
# Closed set of per-model states reported in ``ModelStatusEntry.state``.
ModelState = Literal["available", "missing", "downloading"]
class DownloadProgress(BaseModel):
    """Progress snapshot attached to a model entry whose state is ``downloading``."""

    # Bytes received so far.
    bytes_downloaded: int
    # Expected size in bytes; null when not known.
    total_bytes: Optional[int] = None
    # Fraction in [0, 1]; null until the total is known.
    progress: Optional[float] = None
class ModelStatusEntry(BaseModel):
    """Everything the UI needs to render one model row, in a single object.

    ``state`` reflects what the server has on disk plus what is in flight.
    ``file_size`` and ``is_hf_downloadable`` come from URL probes and are
    included on every poll, so license-acceptance flips show up within one
    poll interval without any frontend cache invalidation.
    """

    state: ModelState
    # Only set while ``state`` is ``downloading``.
    progress: Optional[DownloadProgress] = None
    # Size of the remote file in bytes, when the probe could determine it.
    file_size: Optional[int] = None
    # HF-only: True iff the server can fetch this URL with current auth
    # state. False iff gated and lacking access. None for non-HF URLs.
    is_hf_downloadable: Optional[bool] = None
class HfAuthStatus(BaseModel):
    """HF login snapshot embedded in the availability response."""

    # Whether the server currently holds an HF token.
    token_available: bool
    # Whether this deployment may offer interactive HF login.
    eligible: bool
class AvailabilityStatusResponse(BaseModel):
    """Response of ``POST /api/models-availability-status``."""

    # One status entry per requested model_id.
    models: dict[str, ModelStatusEntry]
    hf_auth: HfAuthStatus
class DownloadModelsResponse(BaseModel):
    """Response of ``POST /api/download-models``."""

    accepted: bool
    # model_ids whose downloads were scheduled by this request.
    scheduled: list[str]
class CancelDownloadSessionResponse(BaseModel):
    """Response of ``POST /api/cancel-model-download-session``."""

    cancelled: bool
class HfAuthTokenStatusResponse(BaseModel):
    """Response of ``GET /api/hf-auth-token-status``."""

    token_available: bool
    # HF user name for the stored token; null when it could not be resolved.
    username: Optional[str] = None
class HfAuthLoginStartResponse(BaseModel):
    """Response of ``POST /api/hf-auth-login-start``."""

    # URL the user must open to continue the HF OAuth flow.
    authorize_url: str
class HfAuthLogoutResponse(BaseModel):
    """Response of ``POST /api/hf-auth-logout``."""

    logged_out: bool
__all__ = [
"ModelState",
"DownloadProgress",
"ModelStatusEntry",
"HfAuthStatus",
"AvailabilityStatusResponse",
"DownloadModelsResponse",
"CancelDownloadSessionResponse",
"HfAuthTokenStatusResponse",
"HfAuthLoginStartResponse",
"HfAuthLogoutResponse",
]

View File

@ -0,0 +1,179 @@
"""Process-wide registry of in-flight model downloads.
A single instance, ``DOWNLOAD_SERVER``, tracks every currently-running
server-side model fetch. Designed to be safe with multiple concurrent
clients hitting the API: each model_id has at most one active session,
and the API rejects requests that conflict with in-flight downloads.
Cancellation is cooperative — the download loop checks ``is_active`` on
its own session between chunks and raises ``DownloadCancelled`` when the
session has been removed from the registry. This avoids the complications
of ``Task.cancel()`` from outside the loop while still giving deterministic
rollback semantics (the worker is responsible for deleting its own
``.tmp`` on the cancel path).
"""
from __future__ import annotations
import logging
import os
import threading
from dataclasses import dataclass, field
from typing import Optional
from app.model_downloader.paths import iter_all_tmp_paths
class DownloadCancelled(Exception):
    """Signal that a download's session was removed from the registry.

    Raised by the streaming loop after a cancellation request; the worker
    should roll back its ``.tmp`` file when it sees this.
    """
@dataclass
class DownloadSession:
    """One in-flight download tracked by the registry.

    ``progress`` is a fraction in ``[0.0, 1.0]`` and stays ``None`` until a
    positive ``total_bytes`` has been reported for the session.
    ``total_bytes`` mirrors the response's ``Content-Length`` when present.
    """

    model_id: str
    url: str
    progress: Optional[float] = None
    bytes_downloaded: int = 0
    total_bytes: Optional[int] = None
    # Sequence number used solely as identity for the cancellation check —
    # so that "cancel + restart" doesn't get confused by stale workers.
    # A plain default suffices: ints are immutable, so no factory is needed.
    epoch: int = 0
class DownloadServer:
    """Registry of active downloads, used as a process-wide singleton.

    Every mutation goes through this object so concurrent route handlers see
    a consistent view. ``_lock`` is a plain threading lock because the
    registry is consulted both from the asyncio event-loop thread (route
    handlers) and from the worker coroutines that perform downloads.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._sessions: dict[str, DownloadSession] = {}
        self._epoch_counter = 0
        self._orphan_sweep_done = False

    # ----- internal -----

    def _registered_match(self, session: DownloadSession) -> Optional[DownloadSession]:
        """Return the registered session if it is this exact one, else ``None``.

        "Exact" means same ``model_id`` slot and same ``epoch``. The caller
        must already hold ``self._lock``.
        """
        registered = self._sessions.get(session.model_id)
        if registered is not None and registered.epoch == session.epoch:
            return registered
        return None

    # ----- lifecycle -----

    def sweep_orphan_tmp_files(self) -> None:
        """Delete ``*.tmp`` files left behind by crashed downloads, once.

        Only the first call does any work; later calls return immediately.
        Kept off the import path so module load doesn't block on filesystem
        I/O against potentially-slow mounts.
        """
        with self._lock:
            already_swept = self._orphan_sweep_done
            self._orphan_sweep_done = True
        if already_swept:
            return

        for path in iter_all_tmp_paths():
            try:
                os.remove(path)
            except OSError as e:
                logging.warning("[model_downloader] could not remove %s: %s", path, e)
            else:
                logging.info("[model_downloader] removed orphan tmp file: %s", path)

    # ----- queries -----

    def is_downloading(self, model_id: str) -> bool:
        """True iff a session is currently registered for ``model_id``."""
        with self._lock:
            return model_id in self._sessions

    def get(self, model_id: str) -> Optional[DownloadSession]:
        """Return the session registered for ``model_id``, or ``None``."""
        with self._lock:
            return self._sessions.get(model_id)

    def snapshot(self) -> dict[str, DownloadSession]:
        """Return a shallow copy of the current sessions map."""
        with self._lock:
            return dict(self._sessions)

    def is_active(self, session: DownloadSession) -> bool:
        """True iff this exact session is still the registered one.

        False after cancellation, after completion, or once another session
        has taken over the same ``model_id``.
        """
        with self._lock:
            return self._registered_match(session) is not None

    # ----- mutations -----

    def try_register(self, model_id: str, url: str) -> Optional[DownloadSession]:
        """Atomically register a new session iff none exists for ``model_id``.

        Returns the new session on success and ``None`` when one is already
        in flight. Callers must check the return value — the caller is the
        sole owner of the session it gets back.
        """
        with self._lock:
            if model_id in self._sessions:
                return None
            next_epoch = self._epoch_counter + 1
            self._epoch_counter = next_epoch
            created = DownloadSession(model_id=model_id, url=url, epoch=next_epoch)
            self._sessions[model_id] = created
            return created

    def update_progress(
        self,
        session: DownloadSession,
        bytes_downloaded: int,
        total_bytes: Optional[int],
    ) -> None:
        """Record progress on a session.

        No-op if the session is no longer the registered one (cancelled or
        replaced) — callers should check ``is_active`` separately.
        ``progress`` is only refreshed when ``total_bytes`` is positive.
        """
        with self._lock:
            live = self._registered_match(session)
            if live is None:
                return
            live.bytes_downloaded = bytes_downloaded
            live.total_bytes = total_bytes
            if total_bytes and total_bytes > 0:
                live.progress = min(1.0, bytes_downloaded / total_bytes)

    def finish(self, session: DownloadSession) -> None:
        """Remove a completed (or cancelled) session from the registry.

        Safe to call multiple times. The entry is removed only when the epoch
        matches, so a *newer* session for the same model_id is never evicted.
        """
        with self._lock:
            if self._registered_match(session) is not None:
                del self._sessions[session.model_id]

    def cancel(self, model_id: str) -> bool:
        """Remove the session registered for ``model_id``.

        Returns True if there was an active session to cancel. The worker
        discovers the cancellation on its next ``is_active`` check and rolls
        back its ``.tmp`` file.
        """
        with self._lock:
            try:
                del self._sessions[model_id]
            except KeyError:
                return False
            return True

    def reset_for_tests(self) -> None:
        """Clear all sessions and reset the epoch counter. Test-only."""
        with self._lock:
            self._sessions.clear()
            self._epoch_counter = 0


# The single process-wide registry instance.
DOWNLOAD_SERVER = DownloadServer()

View File

@ -0,0 +1,216 @@
"""Streaming download worker with progress reporting and cancellation.
Each download writes to ``<final_path>.tmp`` and atomically renames into
place on success. Between chunks the worker checks the registry for
cancellation (via ``DownloadServer.is_active``) and rolls back its
``.tmp`` on cancel or on any error.
"""
from __future__ import annotations
import asyncio
import logging
import os
from typing import Optional
import aiohttp
from app.model_downloader.download_server import (
DOWNLOAD_SERVER,
DownloadCancelled,
DownloadSession,
)
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
from app.model_downloader.hf_url import is_hf_url
from app.model_downloader.http_client import get_session, parse_content_length
from app.model_downloader.paths import resolve_destination
# Number of bytes requested per read from the response stream.
CHUNK_SIZE = 64 * 1024 # 64 KiB — same scale as other ComfyUI download paths.
# Timeouts applied to each download GET: no overall deadline (``total=None``),
# with 30 s / 120 s limits on the socket connect and socket read phases.
REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=120)
async def stream_to_disk(session: DownloadSession) -> str:
    """Run a single download to completion or cancellation.

    Streams ``session.url`` into a ``.tmp`` file and renames it onto the
    final path once the whole body has been received.

    Returns:
        The final on-disk path.

    Raises:
        DownloadCancelled: the session was removed from the registry while
            the download was running.
        DownloadError: the server answered with a non-200 status, or the
            byte count differs from the advertised ``Content-Length``.
        Exception: any other failure (destination setup, network, disk) is
            logged and re-raised.

    On every failure path the ``.tmp`` file is removed. In all terminal
    states, including a failure while preparing the destination, the session
    is dropped from the registry here — callers do not need to call
    ``DOWNLOAD_SERVER.finish`` themselves.
    """
    # Stays None until the destination has been resolved, so the cleanup
    # handlers know whether there is anything on disk to roll back.
    tmp_path: Optional[str] = None
    bytes_seen = 0
    try:
        # Destination setup lives inside the try so that a failure here still
        # reaches the ``finally`` below; otherwise the session would remain
        # registered and the model would be reported as downloading forever.
        final_path, tmp_path = resolve_destination(session.model_id, session.epoch)
        os.makedirs(os.path.dirname(final_path), exist_ok=True)
        # Wipe any stale .tmp from a previous failed attempt before we start —
        # otherwise a partial body could masquerade as our completed download
        # when the rename finally happens.
        _remove_if_exists(tmp_path)

        http = await get_session()
        headers = _auth_headers_for(session.url)
        logging.info(
            "[model_downloader] starting GET %s (auth=%s)",
            session.url, "yes" if "Authorization" in headers else "no",
        )
        async with http.get(
            session.url,
            allow_redirects=True,
            timeout=REQUEST_TIMEOUT,
            headers=headers,
        ) as resp:
            if resp.status != 200:
                # Capture a snippet of the response body so 4xx/5xx aren't
                # opaque in the logs — HF returns JSON or HTML with a
                # human-readable reason on failures.
                body_snippet = await _read_short(resp)
                logging.warning(
                    "[model_downloader] GET %s failed: status=%d final_url=%s body=%s",
                    session.url, resp.status, str(resp.url), body_snippet,
                )
                raise DownloadError(
                    f"unexpected HTTP {resp.status} fetching {session.url}: {body_snippet}",
                    status=resp.status,
                )
            total = parse_content_length(resp.headers.get("Content-Length"))
            DOWNLOAD_SERVER.update_progress(session, 0, total)
            with open(tmp_path, "wb") as f:
                async for chunk in resp.content.iter_chunked(CHUNK_SIZE):
                    # Cancellation check between chunks, so cancellation
                    # latency is bounded by one chunk plus one ``write()``.
                    if not DOWNLOAD_SERVER.is_active(session):
                        raise DownloadCancelled()
                    f.write(chunk)
                    bytes_seen += len(chunk)
                    DOWNLOAD_SERVER.update_progress(session, bytes_seen, total)
        # Final cancellation check before we promote the .tmp to the real
        # filename — avoids the awkward case where cancel arrives during
        # the very last chunk and we'd otherwise commit anyway.
        if not DOWNLOAD_SERVER.is_active(session):
            raise DownloadCancelled()
        # Defensive size verification before commit. This is only possible
        # when the server sent a Content-Length; without that header a short
        # read cannot be detected here.
        if total is not None and bytes_seen != total:
            raise DownloadError(
                f"size mismatch for {session.model_id}: "
                f"got {bytes_seen} of {total} bytes from {session.url}"
            )
        # os.replace swaps the file into place in a single step; tmp lives
        # alongside final_path, so both are on the same filesystem.
        os.replace(tmp_path, final_path)
        logging.info(
            "[model_downloader] downloaded %s (%d bytes) from %s",
            session.model_id, bytes_seen, session.url,
        )
        return final_path
    except DownloadCancelled:
        logging.info("[model_downloader] cancelled: %s", session.model_id)
        if tmp_path is not None:
            _remove_if_exists(tmp_path)
        raise
    except asyncio.CancelledError:
        # The surrounding task itself was cancelled (e.g. server shutdown).
        # On current Python this is not an ``Exception`` subclass, so without
        # this branch the partial file would be left on disk.
        logging.info("[model_downloader] task cancelled: %s", session.model_id)
        if tmp_path is not None:
            _remove_if_exists(tmp_path)
        raise
    except Exception as e:
        logging.warning(
            "[model_downloader] failed: %s from %s: %s: %s",
            session.model_id, session.url, type(e).__name__, e,
            exc_info=True,
        )
        if tmp_path is not None:
            _remove_if_exists(tmp_path)
        raise
    finally:
        # In all terminal states (success / cancel / error) drop the
        # session from the registry. Idempotent — only removes if we're
        # still the live epoch for this model_id.
        DOWNLOAD_SERVER.finish(session)
class DownloadError(Exception):
    """Network or protocol failure during a download.

    ``status`` carries the HTTP status code when the failure was a bad
    response, and is ``None`` otherwise.
    """

    def __init__(self, message: str, status: Optional[int] = None) -> None:
        self.status = status
        super().__init__(message)
async def _read_short(resp: aiohttp.ClientResponse, limit: int = 512) -> str:
"""Read up to ``limit`` bytes of a response body for logging.
Used to surface the JSON/HTML reason from an HF non-2xx response in
server logs instead of just the status code. Best-effort: any
error here is swallowed.
"""
try:
raw = await resp.content.read(limit)
return raw.decode("utf-8", errors="replace").strip()
except Exception:
return "<unreadable>"
def _auth_headers_for(url: str) -> dict[str, str]:
    """Return the auth headers to send with the GET for ``url``.

    HuggingFace URLs get the stored OAuth access token as a Bearer header —
    this is HF's documented way to access gated repos (see
    ``huggingface_hub.hf_hub_download``'s wire format). Every other host, and
    an HF URL with no usable stored token, gets an empty dict: allowlisted
    public files don't need auth and tokens must not leak to other hosts.
    """
    if is_hf_url(url):
        tok = HF_AUTH_STORE.get_token_sync()
        if tok is not None and tok.access_token:
            return {"Authorization": f"Bearer {tok.access_token}"}
    return {}
def _remove_if_exists(path: str) -> None:
try:
os.remove(path)
except FileNotFoundError:
pass
except OSError as e:
logging.warning("[model_downloader] could not remove %s: %s", path, e)
async def run_batch_sequential(sessions: list[DownloadSession]) -> None:
    """Download each session in turn, one at a time.

    Sessions are independent: a failure or cancellation of one does not
    abort the rest. A session that was cancelled before the worker reached
    it is no longer active in the registry and is skipped without touching
    disk.
    """
    for session in sessions:
        if DOWNLOAD_SERVER.is_active(session):
            try:
                await stream_to_disk(session)
            except Exception:
                # Covers ``DownloadCancelled`` as well as real failures.
                # stream_to_disk has already logged the outcome and removed
                # the .tmp, so just move on to the next session.
                continue
        else:
            # Pre-cancelled: this is what makes the per-request "sequential
            # but cancellable" semantic work.
            DOWNLOAD_SERVER.finish(session)
# Strong references to batch tasks that are still running. The event loop
# itself only keeps weak references to tasks, so a fire-and-forget task whose
# result is discarded by the caller could otherwise be garbage-collected
# before it finishes.
_BACKGROUND_TASKS: set[asyncio.Task] = set()


def schedule_batch(sessions: list[DownloadSession]) -> asyncio.Task:
    """Kick off ``run_batch_sequential`` on the running event loop.

    The task is fire-and-forget from the caller's point of view: the API
    handler returns immediately after scheduling and clients observe progress
    via the polling endpoints. The task is kept alive by this module until it
    completes, so callers may safely ignore the return value.

    Returns:
        The scheduled task.
    """
    task = asyncio.create_task(run_batch_sequential(sessions))
    _BACKGROUND_TASKS.add(task)
    # Release the reference once the batch is done so the set does not grow.
    task.add_done_callback(_BACKGROUND_TASKS.discard)
    return task

View File

@ -0,0 +1,245 @@
"""Per-URL probes for the unified availability endpoint.
Three cached/derived facts per URL:
- ``is_gated``: intrinsic to the model; cached forever once known.
  Determined by ``auth_check(repo_id, token=None)``:
  ``GatedRepoError`` → True, success → False.
- ``is_hf_downloadable``: depends on the *current* token; recomputed every
  call. For non-gated URLs this is trivially True (no HF call needed). For
  gated URLs we run ``auth_check`` with the stored token each call.
- ``file_size``: intrinsic to the file. Cached forever once determined
  (including ``None`` on transient failure — we don't retry). We only
  attempt the HEAD when we already know the URL is downloadable to us; that
  way a failed-because-gated probe never lands as a cached ``None``.
Caches are per-process, in-memory; small, no eviction needed for the
workflow-scale (~tens of URLs). Concurrent calls for the same URL
deduplicate via per-URL ``asyncio.Lock``.
"""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from typing import Optional
import aiohttp
from huggingface_hub import HfApi
from huggingface_hub.errors import (
GatedRepoError,
HfHubHTTPError,
RepositoryNotFoundError,
)
from app.model_downloader.allowlist import is_url_allowed
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
from app.model_downloader.hf_url import is_hf_url, repo_id_from_url
from app.model_downloader.http_client import get_session, parse_content_length
# Overall 15-second budget for a probe request.
# NOTE(review): presumably applied to the file-size HEAD request — the use
# site is outside this view; confirm.
_HEAD_TIMEOUT = aiohttp.ClientTimeout(total=15)
@dataclass
class ProbeResult:
    """Outcome of probing one URL, as returned by ``probe_url``."""

    # Size of the remote file in bytes; None when it is unknown or was not
    # probed.
    file_size: Optional[int]
    # True/False for HuggingFace URLs depending on whether the current token
    # state allows the download; None when not applicable or undetermined.
    is_hf_downloadable: Optional[bool]
# --- caches -------------------------------------------------------------- #
# All three maps are module-level, per-process state keyed by URL; they are
# emptied together by ``clear_caches_for_tests``.
# url → bool. Whether this URL's HF repo gates access. Intrinsic to the
# model — never changes for a given URL.
_is_gated_cache: dict[str, bool] = {}
# url → Optional[int]. The file's size in bytes, ``None`` if a probe
# was attempted and produced no answer. **Only populated when we knew
# the URL was downloadable to us at probe time** — so gated-without-
# access never lands a ``None`` here that we'd be stuck with after login.
_file_size_cache: dict[str, Optional[int]] = {}
# Per-URL locks for single-flight probes — when multiple polls arrive
# in the same tick for the same URL, exactly one of them runs the HF
# call and the others wait on the result. Created lazily by ``_lock_for``.
_locks: dict[str, asyncio.Lock] = {}
def _lock_for(url: str) -> asyncio.Lock:
    """Return the single-flight lock for ``url``, creating it on first use."""
    try:
        return _locks[url]
    except KeyError:
        created = asyncio.Lock()
        _locks[url] = created
        return created
def clear_caches_for_tests() -> None:
    """Test-only helper: empty both probe caches and the per-URL lock table."""
    for table in (_is_gated_cache, _file_size_cache, _locks):
        table.clear()
# --- public entrypoint --------------------------------------------------- #
async def probe_url(url: str) -> ProbeResult:
    """Probe one URL for size and HuggingFace downloadability.

    Returns a ``ProbeResult`` with both fields ``None`` when the URL is not
    allowlisted, is an HF URL we cannot map to a repo, or whose gating could
    not be determined. Non-HF URLs get a size but ``is_hf_downloadable=None``
    ("not applicable"). Cached answers are reused where that is safe.
    """
    if not is_url_allowed(url):
        return ProbeResult(file_size=None, is_hf_downloadable=None)

    if not is_hf_url(url):
        # Size is cached, so repeated polls do not re-issue the HEAD.
        plain_size = await _get_or_probe_size(url, token=None)
        return ProbeResult(file_size=plain_size, is_hf_downloadable=None)

    repo_id = repo_id_from_url(url)
    if repo_id is None:
        return ProbeResult(file_size=None, is_hf_downloadable=None)

    # Gating is intrinsic to the repo and resolved (then cached) once.
    gated = await _resolve_is_gated(url, repo_id)
    if gated is None:
        return ProbeResult(file_size=None, is_hf_downloadable=None)

    # Downloadability depends on the *current* token, so it is recomputed
    # on every call.
    tok = await HF_AUTH_STORE.get_valid_token()
    token_str: Optional[str] = tok.access_token if tok else None
    downloadable: Optional[bool]
    if gated:
        downloadable = await _auth_check_with_token(repo_id, token_str)
    else:
        downloadable = True

    # Only HEAD when we know we can fetch: a HEAD that 401s would cache a
    # ``None`` size that survives a later login.
    size: Optional[int] = None
    if downloadable is True:
        size = await _get_or_probe_size(url, token=token_str)
    return ProbeResult(file_size=size, is_hf_downloadable=downloadable)
# --- gated/auth probes --------------------------------------------------- #
async def _resolve_is_gated(url: str, repo_id: str) -> Optional[bool]:
    """Work out (once per URL) whether ``repo_id`` is a gated repo.

    Returns the cached/definite answer, or ``None`` when the probe failed
    for a reason that is worth retrying on the next call (nothing is cached
    in that case).
    """
    known = _is_gated_cache.get(url)
    if known is not None:
        return known

    async with _lock_for(url):
        # Re-check: another waiter may have filled the cache meanwhile.
        known = _is_gated_cache.get(url)
        if known is not None:
            return known

        # Probe anonymously (token=None) on purpose: an unauthenticated
        # auth_check is what makes HF raise GatedRepoError for gated repos.
        # With a token, a gated-but-accepted repo would succeed and look
        # ungated.
        try:
            await asyncio.to_thread(_auth_check_sync, repo_id, None)
        except (GatedRepoError, RepositoryNotFoundError):
            # Gated, or not publicly visible. The latter is treated as gated
            # too: we can't serve it without auth, and an authenticated
            # check might still succeed for a private repo the user can see.
            verdict = True
        except Exception as e:
            logging.debug(
                "[hf_auth] is_gated probe failed for %s (will retry): %s",
                repo_id, e,
            )
            return None  # don't cache; retry next call
        else:
            verdict = False

        _is_gated_cache[url] = verdict
        return verdict
async def _auth_check_with_token(
    repo_id: str, token: Optional[str]
) -> Optional[bool]:
    """Run an auth check for ``repo_id`` with ``token``.

    Returns ``True`` when the check passes, ``False`` when access is denied
    (gated, not found, or HTTP 401/403), and ``None`` on any other failure.
    """
    try:
        await asyncio.to_thread(_auth_check_sync, repo_id, token)
    except (GatedRepoError, RepositoryNotFoundError):
        return False
    except HfHubHTTPError as e:
        # 401/403 covers org-SSO-required, revoked tokens, and similar —
        # all of which mean "can't fetch right now" from the user's POV.
        response = getattr(e, "response", None)
        if getattr(response, "status_code", None) in (401, 403):
            return False
        logging.debug(
            "[hf_auth] auth_check transient failure for %s: %s", repo_id, e,
        )
        return None
    except Exception as e:
        logging.warning("[hf_auth] unexpected auth_check error for %s: %s", repo_id, e)
        return None
    return True
def _auth_check_sync(repo_id: str, token: Optional[str]) -> None:
    """Thin sync wrapper around ``HfApi.auth_check`` for ``asyncio.to_thread``.

    ``token`` is the access token to authenticate with; ``None`` (or an empty
    string) means "send no credentials". Raises whatever ``auth_check``
    raises; returns ``None`` when the check passes.
    """
    # huggingface_hub treats ``token=None`` as "use the locally saved token"
    # (e.g. ``HF_TOKEN`` or a CLI login on this machine), which would make the
    # anonymous gating probe authenticate behind our back and cache a gated
    # repo as ungated. ``False`` is the documented way to disable auth.
    HfApi().auth_check(repo_id, token=token or False)
# --- size probe ---------------------------------------------------------- #
async def _get_or_probe_size(url: str, token: Optional[str]) -> Optional[int]:
    """Serve the size from cache, or HEAD the URL once and remember the answer.

    The result is cached even when it is ``None``.
    """
    try:
        return _file_size_cache[url]
    except KeyError:
        pass

    async with _lock_for(url):
        # Another waiter may have populated the entry while we queued.
        if url not in _file_size_cache:
            _file_size_cache[url] = await _probe_size_once(url, token=token)
        return _file_size_cache[url]
async def _probe_size_once(url: str, token: Optional[str]) -> Optional[int]:
    """HEAD the URL and return the file size in bytes, or None on failure.

    HuggingFace serves LFS-tracked files via a 302 to a signed CDN URL.
    The real file size lives in the non-standard ``X-Linked-Size`` header
    on that 302 response (``Content-Length`` is the redirect-body length).
    Disabling redirect-follow lets us read either header on the same
    response:

    - LFS files: 302 + ``X-Linked-Size``
    - Small/non-LFS files: 200 + ``Content-Length``

    ``token``, when set, is sent as a Bearer ``Authorization`` header.
    Network errors and timeouts are swallowed and reported as ``None``.
    """
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    try:
        session = await get_session()
        async with session.head(
            url, allow_redirects=False, timeout=_HEAD_TIMEOUT, headers=headers,
        ) as resp:
            linked = parse_content_length(resp.headers.get("X-Linked-Size"))
            if linked is not None:
                return linked
            if resp.status == 200:
                return parse_content_length(resp.headers.get("Content-Length"))
            return None
    except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError, OSError):
        # ``asyncio.TimeoutError`` is listed explicitly: before Python 3.11 it
        # is NOT the builtin ``TimeoutError``, so a timed-out HEAD would
        # otherwise escape and fail the whole probe.
        return None
# Backward-compat shim so consumers that still import the old name keep
# building during the refactor; can be removed once routes are updated.
# This is a plain alias: both names refer to the very same class.
MetadataProbeResult = ProbeResult

View File

@ -0,0 +1,121 @@
"""In-memory token cache with lazy disk persistence + refresh.
Public surface is the ``HF_AUTH_STORE`` singleton. Callers ask
``get_valid_token()``; the store transparently refreshes from disk
on first use, refreshes via the OAuth refresh_token if the cached
access_token is expired, and returns ``None`` if neither path works.
The refresh path imports ``oauth.refresh_access_token`` lazily to
avoid an import cycle (oauth needs the store to save tokens it
acquires).
"""
from __future__ import annotations
import logging
import threading
from typing import Optional
from app.model_downloader.hf_auth.token_store import (
Token,
delete_token,
load_token,
save_token,
)
class HfAuthStore:
    """In-memory HuggingFace token cache backed by the on-disk token file.

    The disk token is loaded lazily on first access. All mutations of the
    in-memory token happen under ``self._lock``; disk writes/deletes happen
    inside the same critical section so memory and disk change together.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._token: Optional[Token] = None
        # Flips to True after the first disk read, or after any explicit
        # set/clear (at which point the disk copy is no longer authoritative).
        self._loaded_from_disk = False

    def _ensure_loaded(self) -> None:
        """Read the disk token into memory on first access."""
        if self._loaded_from_disk:
            return
        with self._lock:
            # Re-check under the lock: another caller may have loaded it.
            if self._loaded_from_disk:
                return
            self._token = load_token()
            self._loaded_from_disk = True

    def has_token(self) -> bool:
        """Cheap check: is there any token in memory?

        Does not attempt refresh; an expired-but-refreshable token still
        counts as "logged in" from the user's perspective.
        """
        self._ensure_loaded()
        return self._token is not None

    def _store_token_locked(self, token: Token) -> None:
        """Set the in-memory token and persist it to disk.

        Caller must already hold ``self._lock``. Keeping the disk write inside
        the lock means memory and disk flip together — a concurrent ``clear()``
        or refresh can't interleave between them.
        """
        self._token = token
        self._loaded_from_disk = True
        save_token(token)

    def _valid_replacement_locked(self, stale: Token) -> Optional[Token]:
        """Return the current token if it replaced ``stale`` and is usable.

        Caller must already hold ``self._lock``. Returns ``None`` after a
        logout, when the token is unchanged, or when the replacement is
        itself not valid.
        """
        current = self._token
        if current is None or current is stale:
            return None
        return current if current.is_valid() else None

    def set_token(self, token: Token) -> None:
        """Replace the in-memory token and persist to disk (atomically)."""
        with self._lock:
            self._store_token_locked(token)

    def clear(self) -> None:
        """Forget the token in memory and on disk (logout)."""
        with self._lock:
            self._token = None
            self._loaded_from_disk = True
            delete_token()

    def get_token_sync(self) -> Optional[Token]:
        """Return the cached token without refreshing.

        Sync callers (e.g. constructing an Authorization header in a
        non-async path) use this. They accept an expired token over
        ``None``; HF will simply return 401 and the caller can decide
        what to do.
        """
        self._ensure_loaded()
        return self._token

    async def get_valid_token(self) -> Optional[Token]:
        """Return a fresh token, refreshing via OAuth if necessary.

        Returns ``None`` if there's no cached token at all, or if the
        cached token is expired and no valid token could be obtained.
        Callers should treat that as "user is not logged in".
        """
        self._ensure_loaded()
        with self._lock:
            tok = self._token
        if tok is None:
            return None
        if tok.is_valid():
            return tok
        if not tok.refresh_token:
            return None

        # Lazy import to avoid the oauth ↔ store import cycle.
        from app.model_downloader.hf_auth.oauth import refresh_access_token

        try:
            refreshed = await refresh_access_token(tok.refresh_token)
        except Exception as e:
            logging.warning("[hf_auth] token refresh failed: %s", e)
            # A concurrent caller may have refreshed successfully (and, with
            # rotating refresh tokens, invalidated ours). Prefer its result
            # over reporting "logged out".
            with self._lock:
                return self._valid_replacement_locked(tok)

        with self._lock:
            if self._token is not tok:
                # The token changed while we awaited the refresh. After a
                # logout (clear) this yields None, so the old session is not
                # resurrected; after a concurrent refresh or a new login it
                # yields that newer valid token instead of a spurious None.
                return self._valid_replacement_locked(tok)
            self._store_token_locked(refreshed)
        return refreshed
# Module-level singleton; this is the public surface other modules import.
HF_AUTH_STORE = HfAuthStore()

View File

@ -0,0 +1,55 @@
"""Whether this deployment is allowed to do interactive HF OAuth.
We only let the server hold a HuggingFace token under a strict trust
assumption: this is a *single tenant local* install. Concretely:
- The server is bound to a loopback address. SSH tunneling /
reverse-proxies can defeat this, but it's the strongest signal
we have without an authentication system.
- ``--multi-user`` is off. A shared token used implicitly by multiple
declared users would be a footgun — one user's gated downloads
would silently authenticate as another.
Anything else and the frontend hides the HF login UI entirely; gated
models continue to show the "acquire it manually" message.
"""
from __future__ import annotations
import ipaddress
import socket
from comfy.cli_args import args
def _is_loopback(host: str | None) -> bool:
    """Return True only if ``host`` refers exclusively to loopback addresses.

    Derived from ``server.is_loopback`` (small, no shared module yet), but
    stricter for mixed results — see the loop below.

    ``host`` may be an IP literal or a hostname. IP literals are checked
    directly (127.0.0.0/8 for IPv4, ::1 for IPv6); ``0.0.0.0`` / ``::`` are
    bind-all wildcards, not loopback, so they return False. Hostnames are
    resolved for both IPv4 and IPv6 and must resolve to at least one address,
    all of which must be loopback. ``None`` and unresolvable names return
    False.
    """
    if host is None:
        return False
    try:
        return ipaddress.ip_address(host).is_loopback
    except ValueError:
        pass  # not an IP literal — fall through to name resolution

    saw_loopback = False
    for family in (socket.AF_INET, socket.AF_INET6):
        try:
            infos = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
        except (socket.gaierror, UnicodeError):
            # No addresses for this family (or an unencodable name).
            continue
        for *_, sockaddr in infos:
            try:
                address = ipaddress.ip_address(sockaddr[0])
            except ValueError:
                # Unparsable address: fail closed.
                return False
            if not address.is_loopback:
                # This gates whether the server may hold a token, so a single
                # non-loopback address must veto the whole host — even if an
                # earlier address was loopback.
                return False
            saw_loopback = True
    return saw_loopback
def is_hf_auth_eligible() -> bool:
    """Tell whether this deployment may surface the HF OAuth flow.

    Eligible only when the server listens on a loopback address and
    ``--multi-user`` is off.
    """
    if not _is_loopback(args.listen):
        return False
    return not args.multi_user

View File

@ -0,0 +1,301 @@
"""OAuth 2.0 PKCE flow against HuggingFace's authorization server.
Wired so that ``POST /api/hf-auth-login-start`` can:
1. Generate state + PKCE verifier/challenge in this process.
2. Spin up a short-lived loopback HTTP server at port 41954 to
receive the redirect callback from HF.
3. Return the ``authorize_url`` for the frontend to open in a new tab.
After the user grants consent on huggingface.co, HF redirects to the
local callback URL with ``code`` and ``state``. The callback server
validates ``state`` (CSRF), exchanges the code for tokens via PKCE,
hands the resulting Token to ``HF_AUTH_STORE.set_token``, and shuts
itself down.
Before this can be exercised end-to-end a maintainer must register a
HuggingFace OAuth app and substitute the ``HF_CLIENT_ID`` placeholder
below. See the comment above the constant for the exact steps.
"""
from __future__ import annotations
import asyncio
import base64
import hashlib
import logging
import secrets
import threading
import time
import aiohttp
from aiohttp import web
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
from app.model_downloader.hf_auth.token_store import Token
from app.model_downloader.http_client import get_session
# --- HF OAuth app registration -------------------------------------------- #
# NOTE: The OAuth client_id below is a placeholder. Before this feature can be
# exercised end-to-end, a maintainer must register a HuggingFace OAuth app
# under a Comfy-Org-controlled HF account and substitute its client_id here.
# Detailed walkthrough is in docs/server-side-model-downloads-handover.html
# ("HuggingFace OAuth app setup" section). Short version:
# 1. huggingface.co → Settings → Connected Apps → "Create app"
# 2. Default Scopes: check ``openid`` + ``profile`` (User Info) and
# ``gated-repos`` (Repository Access). Leave everything else off.
# 3. Redirect URLs: exactly ``http://127.0.0.1:41954/api/auth/huggingface/callback``
# — must match ``REDIRECT_URI`` below; change both in lockstep if you
# change ``CALLBACK_PORT``.
# 4. Save → copy the resulting Client ID into ``HF_CLIENT_ID`` below.
# The client_id is not a secret (it travels through the user's browser in
# plaintext); HF's "Public app" type means there's no client secret to
# manage — PKCE replaces it.
# Placeholder until the real OAuth app is registered (see the note above).
HF_CLIENT_ID = "REPLACE_ME_WITH_COMFY_ORG_HF_OAUTH_CLIENT_ID"
# Where the short-lived callback server listens for HF's redirect.
CALLBACK_HOST = "127.0.0.1"
CALLBACK_PORT = 41954
CALLBACK_PATH = "/api/auth/huggingface/callback"
# Must match the redirect URL registered on the HF OAuth app exactly.
REDIRECT_URI = f"http://{CALLBACK_HOST}:{CALLBACK_PORT}{CALLBACK_PATH}"
# HuggingFace OAuth endpoints.
AUTHORIZE_URL = "https://huggingface.co/oauth/authorize"
TOKEN_URL = "https://huggingface.co/oauth/token"
# Overall time budget (seconds) for one POST to ``TOKEN_URL``.
_TOKEN_REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=30)
# Minimal scope set for the feature:
# - openid : required by HF when the app uses OIDC at all
# - profile : lets ``HfApi.whoami(token=...)`` return a username for the
# settings UI; cosmetic but expected
# - gated-repos : grants the token enough to call ``auth_check`` and
# download files from public gated repos the user has
# accepted the license for. The wider ``read-repos`` scope
# would also work (it includes ``gated-repos``) but it
# additionally grants private-repo read access, which we
# don't need and which makes the consent screen scarier
# for the user.
SCOPE = "openid profile gated-repos"
# Maximum time (seconds) the callback server stays up waiting for the user
# to complete consent on huggingface.co. Past this, the port closes and
# the user has to click "Log in" again.
CALLBACK_TIMEOUT_SECS = 300
# Process-wide lock so two simultaneous /api/hf-auth-login-start
# requests don't fight over port CALLBACK_PORT. Acquired (non-blocking) by
# ``start_login_flow`` and released by the callback-server task when it ends.
_OAUTH_LOCK = threading.Lock()
class OAuthInProgressError(Exception):
    """Another OAuth attempt is already running.

    Raised by ``start_login_flow`` when the process-wide OAuth lock is held.
    """
class OAuthCallbackError(Exception):
    """The OAuth callback returned an error (HF denied, port stolen, etc.).

    Also used when the callback port cannot be bound, in which case it
    propagates out of ``start_login_flow``.
    """
# --- PKCE primitives ------------------------------------------------------ #
def _make_pkce() -> tuple[str, str, str]:
    """Generate fresh PKCE material as ``(verifier, challenge, state)``.

    The verifier stays inside this process. The challenge (unpadded
    base64url of the verifier's SHA-256) and the state travel through the
    user's browser; the state is compared on the callback so a malicious
    cross-origin redirect cannot inject a token.
    """
    verifier = secrets.token_urlsafe(64)
    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    challenge = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")
    state = secrets.token_urlsafe(32)
    return verifier, challenge, state
def _build_authorize_url(challenge: str, state: str) -> str:
    """Compose the HF authorize URL carrying the PKCE challenge and state."""
    from urllib.parse import urlencode

    query = urlencode(
        [
            ("client_id", HF_CLIENT_ID),
            ("redirect_uri", REDIRECT_URI),
            ("response_type", "code"),
            ("scope", SCOPE),
            ("state", state),
            ("code_challenge", challenge),
            ("code_challenge_method", "S256"),
        ]
    )
    return f"{AUTHORIZE_URL}?{query}"
# --- Token exchange ------------------------------------------------------- #
async def _exchange_code(code: str, verifier: str) -> Token:
    """Swap the authorization ``code`` (plus PKCE ``verifier``) for a Token.

    Raises on a non-2xx response (``raise_for_status``) and ``KeyError`` if
    the response body lacks ``access_token``.
    """
    form = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": REDIRECT_URI,
        "client_id": HF_CLIENT_ID,
        "code_verifier": verifier,
    }
    session = await get_session()
    async with session.post(TOKEN_URL, data=form, timeout=_TOKEN_REQUEST_TIMEOUT) as resp:
        resp.raise_for_status()
        payload = await resp.json()

    access_token = payload["access_token"]
    refresh_token = payload.get("refresh_token")
    # ``expires_in`` is relative; store the absolute expiry instead.
    expires_at = time.time() + float(payload.get("expires_in", 3600))
    scope = payload.get("scope", SCOPE)
    return Token(
        access_token=access_token,
        refresh_token=refresh_token,
        expires_at=expires_at,
        scope=scope,
    )
async def refresh_access_token(refresh_token: str) -> Token:
    """Use ``refresh_token`` to obtain a new access token.

    The returned Token carries the new refresh token when the server sends
    one, otherwise the one passed in. Raises on a non-2xx response and
    ``KeyError`` if the body lacks ``access_token``.
    """
    form = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": HF_CLIENT_ID,
    }
    session = await get_session()
    async with session.post(TOKEN_URL, data=form, timeout=_TOKEN_REQUEST_TIMEOUT) as resp:
        resp.raise_for_status()
        payload = await resp.json()

    access_token = payload["access_token"]
    # If HF doesn't rotate refresh tokens, keep using the existing one.
    next_refresh = payload.get("refresh_token", refresh_token)
    expires_at = time.time() + float(payload.get("expires_in", 3600))
    scope = payload.get("scope", SCOPE)
    return Token(
        access_token=access_token,
        refresh_token=next_refresh,
        expires_at=expires_at,
        scope=scope,
    )
# --- Callback server ------------------------------------------------------ #
# Strong references to in-flight callback-server tasks. The event loop keeps
# only weak references to tasks, so a fire-and-forget task may be garbage
# collected before it completes unless something else holds on to it.
_CALLBACK_TASKS: "set[asyncio.Task[None]]" = set()


async def start_login_flow() -> str:
    """Begin one OAuth attempt: spawn the callback server, return the URL.

    Returns the URL the frontend should open in a new tab. Raises
    ``OAuthInProgressError`` if another attempt is already running, and
    ``OAuthCallbackError`` if the callback port could not be bound.

    The callback server runs in the background until the user
    completes consent or until ``CALLBACK_TIMEOUT_SECS`` elapses;
    either way the lock + port are released afterward.
    """
    if not _OAUTH_LOCK.acquire(blocking=False):
        raise OAuthInProgressError()
    try:
        verifier, challenge, state = _make_pkce()
        authorize_url = _build_authorize_url(challenge, state)
        ready: asyncio.Future[None] = asyncio.get_running_loop().create_future()
        task = asyncio.create_task(_run_callback_server(verifier, state, ready))
    except BaseException:
        # Failed before handing the lock to the callback-server task: release it
        # here. (Once the task is spawned, it owns releasing the lock.)
        _OAUTH_LOCK.release()
        raise
    _CALLBACK_TASKS.add(task)
    task.add_done_callback(_CALLBACK_TASKS.discard)
    # Don't return the URL until the callback server is actually bound and
    # listening — otherwise HF could redirect to a port nothing is serving and
    # the login would silently dead-end. ``ready`` raises if the bind failed.
    await ready
    return authorize_url
async def _run_callback_server(
    verifier: str, expected_state: str, ready: "asyncio.Future[None]"
) -> None:
    """Listen for HF's redirect once, capture the token, then shut down.

    Signals ``ready`` once the port is bound (or with an exception if the bind
    fails), so ``start_login_flow`` only hands back a URL on a live server.
    On every exit path the runner is cleaned up and ``_OAUTH_LOCK`` released.
    """
    received: asyncio.Future[Token] = asyncio.get_running_loop().create_future()

    async def handler(request: web.Request) -> web.Response:
        try:
            # CSRF check first: requests without the right state must not be
            # able to influence (or abort) the pending login.
            if request.query.get("state") != expected_state:
                return web.Response(status=400, text="state mismatch")
            err = request.query.get("error")
            if err:
                # Guarded so a repeated or late callback doesn't try to
                # resolve an already-settled future.
                if not received.done():
                    received.set_exception(OAuthCallbackError(f"HF returned: {err}"))
                return web.Response(status=400, text=f"OAuth error: {err}")
            code = request.query.get("code")
            if not code:
                return web.Response(status=400, text="missing code")
            tok = await _exchange_code(code, verifier)
            if not received.done():
                received.set_result(tok)
            return web.Response(
                content_type="text/html",
                text=(
                    "<html><body style='font-family:sans-serif;padding:40px'>"
                    "<h2>HuggingFace login successful</h2>"
                    "<p>You can close this tab and return to ComfyUI.</p>"
                    "</body></html>"
                ),
            )
        except Exception as exc:
            if not received.done():
                received.set_exception(exc)
            return web.Response(status=500, text=str(exc))

    app = web.Application()
    app.router.add_get(CALLBACK_PATH, handler)
    runner = web.AppRunner(app)
    try:
        try:
            await runner.setup()
            # NOTE(review): explicit ``reuse_address=True`` is kept as-is;
            # confirm it is wanted on Windows, where address reuse may let
            # another process bind the same port.
            site = web.TCPSite(runner, CALLBACK_HOST, CALLBACK_PORT, reuse_address=True)
            await site.start()
        except Exception as e:
            # Couldn't bind the callback port (commonly already in use). Tell the
            # waiting start_login_flow via ``ready`` so it surfaces a clear error
            # instead of returning a dead URL. Cleanup + lock release happen in
            # the ``finally`` below.
            logging.warning("[hf_auth] could not start callback server: %s", e)
            if not ready.done():
                ready.set_exception(
                    OAuthCallbackError(f"could not bind callback port {CALLBACK_PORT}: {e}")
                )
            return

        # Bound and listening — now it's safe for start_login_flow to return the URL.
        if not ready.done():
            ready.set_result(None)

        try:
            token = await asyncio.wait_for(received, timeout=CALLBACK_TIMEOUT_SECS)
        except asyncio.TimeoutError:
            logging.info("[hf_auth] OAuth login timed out after %ds", CALLBACK_TIMEOUT_SECS)
            return
        except OAuthCallbackError as e:
            logging.warning("[hf_auth] OAuth callback error: %s", e)
            return
        except Exception as e:
            logging.warning("[hf_auth] unexpected OAuth failure: %s", e)
            return

        HF_AUTH_STORE.set_token(token)
        logging.info("[hf_auth] OAuth login complete")
    finally:
        try:
            # Runs on the bind-failure path too, so a runner whose setup()
            # succeeded but whose site failed to start isn't leaked.
            await runner.cleanup()
        except Exception as e:
            logging.debug("[hf_auth] callback server cleanup failed: %s", e)
        finally:
            # If we are unwinding before the bind was reported (e.g. the task
            # was cancelled), don't leave start_login_flow waiting forever.
            if not ready.done():
                ready.cancel()
            if _OAUTH_LOCK.locked():
                _OAUTH_LOCK.release()
def is_login_in_progress() -> bool:
    """Report whether an OAuth attempt currently holds the process-wide lock."""
    lock_held = _OAUTH_LOCK.locked()
    return lock_held
# Public names of this module (what ``from ... import *`` exposes).
# NOTE(review): ``OAuthCallbackError`` can propagate out of
# ``start_login_flow`` but is not listed here — confirm whether callers are
# expected to import it explicitly.
__all__ = [
    "start_login_flow",
    "refresh_access_token",
    "is_login_in_progress",
    "OAuthInProgressError",
    "CALLBACK_TIMEOUT_SECS",
]

View File

@ -0,0 +1,94 @@
"""On-disk persistence for the HuggingFace OAuth token.
The token shape mirrors what HF returns on the token exchange: an
``access_token``, an optional ``refresh_token``, the absolute epoch at
which the access token expires, and the granted scope. We persist
this so logging in once survives ComfyUI restarts under the internal
``__hf_auth`` system-user directory; the file is mode ``0600`` so only
the OS user can read it.
"""
from __future__ import annotations
import json
import logging
import os
import stat
import time
from dataclasses import asdict, dataclass
from typing import Optional
import folder_paths
# Safety margin: a token counts as expired this many seconds before the
# server-reported ``expires_at``, so it can't go stale between the
# auth_check and the actual GET that follows it.
EXPIRY_BUFFER_SECS = 60
TOKEN_FILENAME = "hf_auth_token.json"


@dataclass
class Token:
    """An OAuth token together with the metadata needed to use it."""

    access_token: str
    refresh_token: Optional[str]
    expires_at: float  # absolute epoch seconds
    scope: str = ""

    def is_valid(self) -> bool:
        """Tell whether the token is usable right now.

        Requires a non-empty access token with more than
        ``EXPIRY_BUFFER_SECS`` of lifetime left.
        """
        if not self.access_token:
            return False
        seconds_left = self.expires_at - time.time()
        return seconds_left > EXPIRY_BUFFER_SECS
def _token_dir() -> str:
    """Return the ``hf_auth`` system-user directory that holds the token file."""
    directory = folder_paths.get_system_user_directory("hf_auth")
    return directory
def _token_path() -> str:
    """Return the full path of the persisted token file."""
    parent = _token_dir()
    return os.path.join(parent, TOKEN_FILENAME)
def load_token() -> Optional[Token]:
    """Read the persisted token, returning ``None`` if absent or corrupt.

    "Corrupt" covers unreadable files, invalid UTF-8 or JSON, a payload that
    is not the expected object shape, and fields of the wrong type. A missing
    file is the normal logged-out case and is not logged.
    """
    path = _token_path()
    try:
        # Open directly instead of exists()-then-open: no race with a
        # concurrent logout deleting the file in between.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return None
    except (OSError, ValueError) as e:
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        logging.warning("[hf_auth] could not load token at %s: %s", path, e)
        return None

    try:
        token = Token(**data)
    except TypeError as e:
        # Not a mapping, or missing/unexpected keys.
        logging.warning("[hf_auth] could not load token at %s: %s", path, e)
        return None

    # Reject wrong field types up front; otherwise ``Token.is_valid`` would
    # raise later on every use instead of the user simply being logged out.
    expires_ok = isinstance(token.expires_at, (int, float)) and not isinstance(
        token.expires_at, bool
    )
    if not isinstance(token.access_token, str) or not expires_ok:
        logging.warning(
            "[hf_auth] could not load token at %s: %s", path, "unexpected field types"
        )
        return None
    return token
def save_token(token: Token) -> None:
    """Atomically write the token with 0600 permissions.

    The token is written to a sibling ``.tmp`` file and then renamed onto the
    final path. Raises ``OSError`` if the file cannot be written; the temp
    file is removed in that case.
    """
    path = _token_path()
    os.makedirs(os.path.dirname(path), exist_ok=True)
    tmp = path + ".tmp"
    owner_only = stat.S_IRUSR | stat.S_IWUSR
    fd = os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            # The mode passed to os.open only applies when the file is
            # *created*. A leftover temp file keeps its old mode, so tighten
            # it before any secret bytes are written.
            try:
                os.chmod(tmp, owner_only)
            except OSError as e:
                logging.debug("[hf_auth] chmod 0600 on %s failed: %s", tmp, e)
            json.dump(asdict(token), f)
        os.replace(tmp, path)
    except BaseException:
        # Don't leave a partial token file behind.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
    try:
        os.chmod(path, owner_only)
    except OSError as e:
        # On Windows / weird filesystems chmod may be a no-op; not fatal.
        logging.debug("[hf_auth] chmod 0600 on %s failed: %s", path, e)
def delete_token() -> None:
    """Delete the token file if present; other removal failures are only logged."""
    target = _token_path()
    try:
        os.remove(target)
    except OSError as e:
        if isinstance(e, FileNotFoundError):
            # Already gone — nothing to do.
            return
        logging.warning("[hf_auth] could not remove token at %s: %s", target, e)

View File

@ -0,0 +1,41 @@
"""Parsers for the ``huggingface.co`` URL shape we accept in workflows.
The download API accepts URLs of the form
``https://huggingface.co/<org>/<repo>/resolve/<rev>/<path/to/file>``.
We need to recover ``<org>/<repo>`` (the *repo_id*) from such URLs for
``huggingface_hub`` API calls (notably ``HfApi.auth_check``).
"""
from __future__ import annotations
from typing import Optional
from urllib.parse import urlparse
_HF_HOST = "huggingface.co"


def is_hf_url(url: str) -> bool:
    """Tell whether ``url``'s parsed hostname is exactly huggingface.co.

    URLs that fail to parse count as "not HF".
    """
    try:
        host = urlparse(url).hostname
    except ValueError:
        return False
    return host == _HF_HOST
def repo_id_from_url(url: str) -> Optional[str]:
    """Recover the ``<org>/<repo>`` repo id from an HF model-file URL.

    Only paths shaped like ``/<org>/<repo>/resolve/<rev>/<path>`` are
    accepted. Anything else — other hosts, datasets, spaces, ``/tree/``,
    ``/blob/``, … — is out of scope here and yields ``None``.
    """
    if not is_hf_url(url):
        return None

    segments = urlparse(url).path.lstrip("/").split("/")
    has_resolve_shape = len(segments) >= 4 and segments[2] == "resolve"
    if not has_resolve_shape:
        return None

    org, repo = segments[:2]
    return f"{org}/{repo}" if org and repo else None

View File

@ -0,0 +1,63 @@
"""Lazy module-level aiohttp ClientSession.
A single shared session means TLS handshakes are reused across HEAD probes
and the subsequent GETs to the same host (HuggingFace is the dominant
case), which is a noticeable speedup on cold connections.
We deliberately don't close the session at process exit — aiohttp's
warning about unclosed sessions is benign at shutdown, and adding atexit
plumbing buys nothing because the OS reclaims the sockets anyway. The
session lifetime is the lifetime of the Python process.
"""
from __future__ import annotations
import asyncio
import ssl
from typing import Optional
import aiohttp
import certifi
# Per-host connection cap handed to ``aiohttp.TCPConnector``.
# NOTE(review): aiohttp's defaults are ``limit=100`` total and
# ``limit_per_host=0``, where 0 means *no* per-host limit. So 8 is a cap, not
# a larger pool: beyond 8 concurrent connections to one host, requests
# (gated probes + a download) will queue. Confirm that is the intent.
_CONNECTOR_LIMIT_PER_HOST = 8
# The shared session; created lazily by ``get_session``.
_session: Optional[aiohttp.ClientSession] = None
# Guards creation of ``_session``.
_lock = asyncio.Lock()
def ssl_context() -> ssl.SSLContext:
    """Build a TLS context that trusts certifi's CA bundle.

    aiohttp's default context uses the OS trust store, which isn't wired up
    on some Python installs (python.org macOS, slim containers) — there TLS
    to huggingface.co fails with CERTIFICATE_VERIFY_FAILED.
    """
    ca_bundle = certifi.where()
    return ssl.create_default_context(cafile=ca_bundle)
async def get_session() -> aiohttp.ClientSession:
    """Hand out the shared session, building it on first use or after a close."""
    global _session
    existing = _session
    if existing is not None and not existing.closed:
        return existing

    async with _lock:
        # Re-check under the lock: a concurrent caller may have created it.
        needs_new = _session is None or _session.closed
        if needs_new:
            _session = aiohttp.ClientSession(
                connector=aiohttp.TCPConnector(
                    limit_per_host=_CONNECTOR_LIMIT_PER_HOST,
                    ssl=ssl_context(),
                )
            )
        return _session
def parse_content_length(value: Optional[str]) -> Optional[int]:
    """Parse a byte-count header value, or None if absent/malformed/negative.

    Only a run of ASCII digits (optionally surrounded by whitespace) is
    accepted. Forms that ``int()`` alone would tolerate — a leading sign,
    digit-group underscores such as ``1_000``, non-ASCII digits — are
    treated as malformed.
    """
    if not value:
        return None
    digits = value.strip()
    if not (digits.isascii() and digits.isdigit()):
        return None
    try:
        return int(digits)
    except ValueError:
        # e.g. an absurdly long digit string rejected by int()'s size limit.
        return None

View File

@ -0,0 +1,111 @@
"""Path resolution for model downloads.
Model identifiers used across the download API are *relative destination
paths* of the form ``<directory>/<filename>`` (e.g. ``loras/my_lora.safetensors``).
This module turns one of those identifiers into an absolute on-disk path
under one of ComfyUI's registered model folders, while rejecting unknown
folders, path traversal, and other ill-formed inputs.
"""
import os
import re
from typing import Optional, Tuple
import folder_paths
# Constrain components so a model_id can never escape its target directory.
# - directory: a single path segment of safe chars
# - filename: a single path segment of safe chars, must end with a model extension
# NOTE: with ``re.match`` the trailing ``$`` also matches just before a final
# newline; use ``fullmatch`` where a strict whole-string match is required.
_SEGMENT_RE = re.compile(r"^[A-Za-z0-9._-]+$")
# Destination filename must name a model file (same set as the URL allowlist),
# so a download can't land as e.g. ``foo.txt`` that ComfyUI won't recognise.
_MODEL_EXTENSIONS = (".safetensors", ".sft", ".ckpt", ".pth", ".pt")
# Distinctive temp suffix so the startup orphan-sweep only removes files THIS
# subsystem created — never unrelated ``*.tmp`` files in the model dirs.
_TMP_SUFFIX = ".comfy-download.tmp"
class InvalidModelId(ValueError):
    """Raised when a model_id is syntactically invalid or refers to an
    unknown model folder.

    Subclasses ``ValueError``, so callers catching ``ValueError`` also see it.
    """
def parse_model_id(model_id: str) -> Tuple[str, str]:
    """Split ``<directory>/<filename>`` and validate both components.

    Returns ``(directory, filename)``. Raises ``InvalidModelId`` on
    malformed input: not a string, not exactly one ``/``, a segment with
    characters outside ``[A-Za-z0-9._-]``, a filename without a model
    extension, or a directory that is not a registered model folder.
    Does NOT touch the filesystem.
    """
    if not isinstance(model_id, str) or "/" not in model_id:
        raise InvalidModelId(f"model_id must be '<directory>/<filename>', got {model_id!r}")
    directory, _, filename = model_id.partition("/")
    if "/" in filename or not directory or not filename:
        raise InvalidModelId(f"model_id must be exactly one '/' separator, got {model_id!r}")
    # ``fullmatch`` rather than ``match``: with ``match`` the pattern's ``$``
    # also matches just before a trailing newline, so ``"loras\n"`` would pass
    # the character check and only be rejected by a later, unrelated check.
    if not _SEGMENT_RE.fullmatch(directory):
        raise InvalidModelId(f"invalid directory segment {directory!r}")
    if not _SEGMENT_RE.fullmatch(filename):
        raise InvalidModelId(f"invalid filename segment {filename!r}")
    if not filename.endswith(_MODEL_EXTENSIONS):
        raise InvalidModelId(
            f"filename must end with a model extension {_MODEL_EXTENSIONS}, got {filename!r}"
        )
    if directory not in folder_paths.folder_names_and_paths:
        raise InvalidModelId(f"unknown model folder {directory!r}")
    return directory, filename
def resolve_existing(model_id: str) -> Optional[str]:
    """Look up an installed model and return its absolute path, else ``None``.

    Resolution goes through ``folder_paths.get_full_path``, so
    ``extra_model_paths.yaml`` entries are honoured transparently. Raises
    ``InvalidModelId`` for a malformed ``model_id``.
    """
    folder, name = parse_model_id(model_id)
    located = folder_paths.get_full_path(folder, name)
    return located
def resolve_destination(model_id: str, epoch: int = 0) -> Tuple[str, str]:
    """Compute ``(final_path, tmp_path)`` for downloading ``model_id``.

    The file lands in the first registered path of the model's directory
    (the "primary" location). ``tmp_path`` is a sibling of ``final_path``
    that the download writes to before being renamed onto it.

    The temp name embeds the session ``epoch`` so a cancel+retry of the same
    model never shares a temp path between the old (cancelling) worker and
    the new attempt — otherwise the old worker's rollback could delete the
    new worker's in-progress file. The distinctive suffix scopes the orphan
    sweep. Raises ``InvalidModelId`` for a bad id or a folder with no
    on-disk path.
    """
    folder, name = parse_model_id(model_id)
    candidates = folder_paths.get_folder_paths(folder)
    if not candidates:
        raise InvalidModelId(f"no on-disk path registered for folder {folder!r}")

    final_path = os.path.join(candidates[0], name)
    tmp_suffix = f".{epoch}{_TMP_SUFFIX}"
    return final_path, final_path + tmp_suffix
def iter_all_tmp_paths():
    """Yield this subsystem's temp files under every registered model folder.

    Matches only our distinctive ``_TMP_SUFFIX`` (not every ``*.tmp``) so the
    startup orphan-sweep can't delete temp files created by other tools.
    Only the top level of each root is scanned; each root is visited once
    even if several folders share it. Unreadable roots are skipped.
    """
    seen_roots: set[str] = set()
    for directory in folder_paths.folder_names_and_paths:
        for root in folder_paths.get_folder_paths(directory):
            if root in seen_roots or not os.path.isdir(root):
                continue
            seen_roots.add(root)
            try:
                # Context manager so the directory handle is closed promptly,
                # even if the consumer stops iterating early.
                with os.scandir(root) as entries:
                    for entry in entries:
                        # Cheap name test first; ``is_file`` may need a stat.
                        if entry.name.endswith(_TMP_SUFFIX) and entry.is_file():
                            yield entry.path
            except OSError:
                # Folder might be unreadable / missing on certain mounts —
                # not fatal, just skip it.
                continue

View File

@ -0,0 +1,708 @@
"""Unit tests for the HuggingFace auth subsystem.
Covers:
- token store: save/load roundtrip, chmod 0600, atomic write, delete
- eligibility under various CLI-arg combinations
- URL parsing (huggingface.co host detection + repo_id extraction)
- HF-aware gated_detection.probe_url (mocked auth_check)
- HF auth routes (token status, login start with eligibility gate, logout)
- PKCE primitives + authorize URL shape
The OAuth callback server itself isn't exercised end-to-end here — that
requires a real HF server. We test the components (state checking,
URL building, code-exchange request shape) instead.
"""
from __future__ import annotations
import os
import stat
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from aiohttp import web
from app.model_downloader.api.routes import register_routes
from app.model_downloader.hf_auth import oauth
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE, HfAuthStore
from app.model_downloader.hf_auth.token_store import (
EXPIRY_BUFFER_SECS,
Token,
delete_token,
load_token,
save_token,
)
from app.model_downloader.hf_url import is_hf_url, repo_id_from_url
# --------------------------------------------------------------------------- #
# Fixtures
# --------------------------------------------------------------------------- #
@pytest.fixture
def patched_user_dir(tmp_path):
    """Provide an isolated user directory for token-file tests.

    Creates ``<tmp_path>/user`` and patches
    ``folder_paths.get_user_directory`` to return it (as a string) while the
    test runs, so the token file never lands in the real user dir. Yields
    the directory itself.
    """
    isolated_dir = tmp_path / "user"
    isolated_dir.mkdir()
    redirect = patch(
        "folder_paths.get_user_directory",
        return_value=str(isolated_dir),
    )
    with redirect:
        yield isolated_dir
def _token_file_path(user_dir) -> str:
return os.path.join(user_dir, "__hf_auth", "hf_auth_token.json")
@pytest.fixture
def fresh_auth_store():
    """Yield the ``HF_AUTH_STORE`` singleton with clean state.

    The in-memory token, the loaded-from-disk flag and the
    ``gated_detection`` caches are reset both before the test and again
    after it.
    """
    from app.model_downloader import gated_detection

    def _wipe():
        # Same reset for setup and teardown.
        HF_AUTH_STORE._token = None
        HF_AUTH_STORE._loaded_from_disk = False
        gated_detection.clear_caches_for_tests()

    _wipe()
    yield HF_AUTH_STORE
    _wipe()
@pytest.fixture
def app(patched_user_dir, fresh_auth_store):
    """Build an aiohttp application with the model-downloader routes registered.

    Requests ``patched_user_dir`` and ``fresh_auth_store`` so both are set
    up before any route runs.
    """
    application = web.Application()
    register_routes(application)
    return application
# --------------------------------------------------------------------------- #
# URL parsing
# --------------------------------------------------------------------------- #
def test_is_hf_url_recognises_huggingface_co():
    """``is_hf_url`` accepts huggingface.co URLs and rejects other hosts."""
    accepted = (
        "https://huggingface.co/x/y/resolve/main/z.safetensors",
        "https://huggingface.co/abc",
    )
    rejected = (
        "https://hf-mirror.com/x/y/resolve/main/z.safetensors",
        "https://civitai.com/x.safetensors",
    )
    for candidate in accepted:
        assert is_hf_url(candidate)
    for candidate in rejected:
        assert not is_hf_url(candidate)
def test_repo_id_from_url_extracts_org_and_repo():
    """A ``resolve`` URL yields its ``org/repo`` identifier."""
    expected_repo = "Lightricks/LTX-2.3-22b-IC-LoRA-HDR"
    resolve_url = "https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-HDR/resolve/main/x.safetensors"
    extracted = repo_id_from_url(resolve_url)
    assert extracted == expected_repo
def test_repo_id_from_url_handles_nested_path():
    """Sub-directories after ``resolve/<rev>/`` don't leak into the repo id."""
    nested_url = "https://huggingface.co/Comfy-Org/ltx-2.3/resolve/main/split_files/loras/x.safetensors"
    extracted = repo_id_from_url(nested_url)
    assert extracted == "Comfy-Org/ltx-2.3"
def test_repo_id_from_url_returns_none_for_non_hf():
    """A URL on a non-HuggingFace host has no repo id."""
    non_hf_url = "https://civitai.com/x.safetensors"
    extracted = repo_id_from_url(non_hf_url)
    assert extracted is None
def test_repo_id_from_url_returns_none_for_non_resolve_paths():
    """HuggingFace URLs that aren't ``resolve`` links have no repo id."""
    non_resolve_urls = (
        "https://huggingface.co/org/repo/blob/main/x.safetensors",
        "https://huggingface.co/org",
    )
    for candidate in non_resolve_urls:
        assert repo_id_from_url(candidate) is None
# --------------------------------------------------------------------------- #
# Token store
# --------------------------------------------------------------------------- #
def test_token_store_roundtrip(patched_user_dir):
    """A token written by ``save_token`` is read back equal by ``load_token``."""
    original = Token(
        access_token="hf_abc",
        refresh_token="rf_def",
        expires_at=9999999999.0,
        scope="openid profile",
    )
    save_token(original)
    restored = load_token()
    assert restored == original
def test_token_store_writes_0600(patched_user_dir):
    """On POSIX the saved token file has mode ``0600``.

    The file is stat'ed on every platform (so it must exist), but the
    permission bits are only asserted on POSIX — elsewhere whatever mode
    the OS assigned is accepted.
    """
    save_token(Token(access_token="x", refresh_token=None, expires_at=0.0))
    file_stat = os.stat(_token_file_path(patched_user_dir))
    permission_bits = stat.S_IMODE(file_stat.st_mode)
    if os.name != "posix":
        return
    assert permission_bits == 0o600
def test_token_store_delete_removes_file(patched_user_dir):
    """``delete_token`` removes the file and tolerates being called twice."""
    save_token(Token(access_token="x", refresh_token=None, expires_at=0.0))
    delete_token()
    token_file = _token_file_path(patched_user_dir)
    still_present = os.path.exists(token_file)
    assert not still_present
    # A second delete with nothing on disk must not raise.
    delete_token()
def test_token_store_load_returns_none_for_missing_file(patched_user_dir):
    """With no token file in the (empty) user dir, ``load_token`` gives ``None``."""
    loaded = load_token()
    assert loaded is None
def test_token_store_load_returns_none_for_corrupt_file(patched_user_dir):
    """A token file holding invalid JSON is reported as ``None``."""
    token_file = _token_file_path(patched_user_dir)
    parent_dir = os.path.dirname(token_file)
    os.makedirs(parent_dir, exist_ok=True)
    with open(token_file, "w") as handle:
        handle.write("not json {")
    loaded = load_token()
    assert loaded is None
def test_token_is_valid_uses_buffer(patched_user_dir):
    """Expiry is judged against ``EXPIRY_BUFFER_SECS``.

    A token an hour from expiry is valid; one that expires one second
    inside the buffer window is not.
    """
    import time

    comfortable_expiry = time.time() + 3600
    fresh = Token(
        access_token="x", refresh_token=None, expires_at=comfortable_expiry
    )
    inside_buffer_expiry = time.time() + EXPIRY_BUFFER_SECS - 1
    nearly_expired = Token(
        access_token="x", refresh_token=None, expires_at=inside_buffer_expiry
    )
    assert fresh.is_valid()
    assert not nearly_expired.is_valid()
def test_token_is_valid_rejects_empty_access_token():
    """An empty access token is invalid even when the expiry is in the future."""
    import time

    future_expiry = time.time() + 3600
    blank = Token(access_token="", refresh_token=None, expires_at=future_expiry)
    assert not blank.is_valid()
def test_token_is_valid_rejects_at_exact_buffer_boundary():
    """A token expiring ``EXPIRY_BUFFER_SECS`` from now is already invalid.

    NOTE: the expiry is computed from ``time.time()`` before ``is_valid``
    runs, so by the time of the check the remaining lifetime is at most
    the buffer, not exactly equal to it.
    """
    import time

    boundary_expiry = time.time() + EXPIRY_BUFFER_SECS
    on_boundary = Token(
        access_token="x", refresh_token=None, expires_at=boundary_expiry
    )
    assert not on_boundary.is_valid()
# --------------------------------------------------------------------------- #
# Auth store
# --------------------------------------------------------------------------- #
def test_auth_store_loads_lazily(patched_user_dir):
    """A new store picks up a token that was saved to disk beforehand."""
    on_disk = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
    save_token(on_disk)
    lazy_store = HfAuthStore()
    assert lazy_store.has_token()
    assert lazy_store.get_token_sync() == on_disk
def test_auth_store_set_persists(patched_user_dir):
    """``set_token`` persists: a second store instance returns the same token."""
    writer = HfAuthStore()
    persisted = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
    writer.set_token(persisted)
    # A separate instance has no in-memory state, so it must read from disk.
    reader = HfAuthStore()
    assert reader.get_token_sync() == persisted
def test_auth_store_clear_removes_in_memory_and_on_disk(patched_user_dir):
    """After ``clear`` neither this store nor a new one reports a token."""
    store = HfAuthStore()
    store.set_token(
        Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
    )
    store.clear()
    assert not store.has_token()
    second_store = HfAuthStore()
    assert second_store.get_token_sync() is None
def test_auth_store_has_token_true_when_expired_but_refreshable(patched_user_dir):
    """``has_token`` stays true for an expired token that has a refresh token."""
    import time

    store = HfAuthStore()
    stale = Token(
        access_token="old",
        refresh_token="rf",
        expires_at=time.time() - 100,
    )
    store.set_token(stale)
    assert store.has_token()
    # The token itself is past its expiry.
    assert not stale.is_valid()
def test_auth_store_get_token_sync_returns_expired_without_refresh(patched_user_dir):
    """``get_token_sync`` hands back the stored token even when it has expired."""
    import time

    store = HfAuthStore()
    stale = Token(
        access_token="old",
        refresh_token=None,
        expires_at=time.time() - 100,
    )
    store.set_token(stale)
    returned = store.get_token_sync()
    assert returned == stale
@pytest.mark.asyncio
async def test_auth_store_get_valid_returns_none_when_expired_without_refresh(
    patched_user_dir,
):
    """Expired token without a refresh token → ``None``, and refresh is never called."""
    import time

    store = HfAuthStore()
    store.set_token(
        Token(access_token="old", refresh_token=None, expires_at=time.time() - 100)
    )
    refresh_patch = patch(
        "app.model_downloader.hf_auth.oauth.refresh_access_token",
        new=AsyncMock(),
    )
    with refresh_patch as refresh_mock:
        outcome = await store.get_valid_token()
    assert outcome is None
    refresh_mock.assert_not_called()
@pytest.mark.asyncio
async def test_auth_store_get_valid_returns_fresh_token(patched_user_dir):
    """A token that is still valid is returned as-is by ``get_valid_token``."""
    import time

    store = HfAuthStore()
    current = Token(
        access_token="x", refresh_token=None, expires_at=time.time() + 3600
    )
    store.set_token(current)
    assert await store.get_valid_token() == current
@pytest.mark.asyncio
async def test_auth_store_get_valid_refresh_on_expired(patched_user_dir):
    """An expired token with a refresh token is swapped for the refreshed one."""
    import time

    store = HfAuthStore()
    stale = Token(
        access_token="old",
        refresh_token="rf",
        expires_at=time.time() - 100,
    )
    store.set_token(stale)
    renewed = Token(
        access_token="new",
        refresh_token="rf",
        expires_at=time.time() + 3600,
    )
    refresh_patch = patch(
        "app.model_downloader.hf_auth.oauth.refresh_access_token",
        new=AsyncMock(return_value=renewed),
    )
    with refresh_patch:
        outcome = await store.get_valid_token()
    assert outcome == renewed
@pytest.mark.asyncio
async def test_auth_store_get_valid_token_does_not_resurrect_after_logout(
    patched_user_dir,
):
    """Logout during an in-flight refresh must win (the resurrection race).

    The stand-in refresh clears the store before returning a new token;
    the store must then discard that token rather than write it back.
    """
    import time

    store = HfAuthStore()
    stale = Token(
        access_token="old", refresh_token="rf", expires_at=time.time() - 100
    )
    store.set_token(stale)
    renewed = Token(
        access_token="new", refresh_token="rf", expires_at=time.time() + 3600
    )

    async def _refresh_with_logout(_refresh_token):
        # Stands in for the user clicking "Log out" mid-refresh.
        store.clear()
        return renewed

    refresh_patch = patch(
        "app.model_downloader.hf_auth.oauth.refresh_access_token",
        new=_refresh_with_logout,
    )
    with refresh_patch:
        outcome = await store.get_valid_token()
    # Refresh result discarded — nothing in memory, nothing on disk.
    assert outcome is None
    assert not store.has_token()
    assert load_token() is None
@pytest.mark.asyncio
async def test_auth_store_get_valid_returns_none_on_refresh_failure(patched_user_dir):
    """If the refresh call raises, ``get_valid_token`` returns ``None``."""
    import time

    store = HfAuthStore()
    store.set_token(
        Token(access_token="old", refresh_token="rf", expires_at=time.time() - 100)
    )
    failing_refresh = AsyncMock(side_effect=RuntimeError("HF down"))
    refresh_patch = patch(
        "app.model_downloader.hf_auth.oauth.refresh_access_token",
        new=failing_refresh,
    )
    with refresh_patch:
        outcome = await store.get_valid_token()
    assert outcome is None
# --------------------------------------------------------------------------- #
# Eligibility
# --------------------------------------------------------------------------- #
@pytest.mark.parametrize(
    "listen,multi_user,expected",
    [
        ("127.0.0.1", False, True),
        ("127.0.0.1", True, False),  # multi-user disables it
        ("0.0.0.0", False, False),  # bind-all is not loopback
        ("0.0.0.0", True, False),
        ("192.168.1.5", False, False),  # LAN address
        ("::1", False, True),  # IPv6 loopback
    ],
)
def test_eligibility(listen, multi_user, expected, monkeypatch):
    """``is_hf_auth_eligible`` for each ``listen`` / ``multi_user`` combination."""
    from app.model_downloader.hf_auth import eligibility
    from comfy.cli_args import args

    overrides = (("listen", listen), ("multi_user", multi_user))
    for attr_name, value in overrides:
        monkeypatch.setattr(args, attr_name, value)
    verdict = eligibility.is_hf_auth_eligible()
    assert verdict is expected
# --------------------------------------------------------------------------- #
# gated_detection HF probe
# --------------------------------------------------------------------------- #
@pytest.mark.asyncio
async def test_probe_url_hf_public(fresh_auth_store):
    """auth_check passes with no token → downloadable, and the size is reported."""
    from app.model_downloader.gated_detection import probe_url

    public_url = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
    size_stub = AsyncMock(return_value=1024)
    with patch("app.model_downloader.gated_detection._auth_check_sync"):
        with patch(
            "app.model_downloader.gated_detection._probe_size_once",
            new=size_stub,
        ):
            probe = await probe_url(public_url)
    assert probe.is_hf_downloadable is True
    assert probe.file_size == 1024
@pytest.mark.asyncio
async def test_probe_url_hf_gated_no_access(fresh_auth_store):
    """auth_check raising ``GatedRepoError`` → ``is_hf_downloadable`` is False."""
    from huggingface_hub.errors import GatedRepoError
    from app.model_downloader.gated_detection import probe_url

    gated_url = "https://huggingface.co/gated/repo/resolve/main/x.safetensors"
    forbidden = MagicMock(status_code=403)
    gated_error = GatedRepoError("gated", response=forbidden)
    with patch(
        "app.model_downloader.gated_detection._auth_check_sync",
        side_effect=gated_error,
    ):
        with patch(
            "app.model_downloader.gated_detection._probe_size_once",
            new=AsyncMock(return_value=None),
        ):
            probe = await probe_url(gated_url)
    assert probe.is_hf_downloadable is False
@pytest.mark.asyncio
async def test_probe_url_non_hf_skips_auth_check():
    """Non-HF URL: auth_check is never called and ``is_hf_downloadable`` is None."""
    from app.model_downloader.gated_detection import probe_url

    civitai_url = "https://civitai.com/api/download/models/1.safetensors"
    size_stub = AsyncMock(return_value=2048)
    with patch(
        "app.model_downloader.gated_detection._auth_check_sync",
    ) as auth_check:
        with patch(
            "app.model_downloader.gated_detection._probe_size_once",
            new=size_stub,
        ):
            probe = await probe_url(civitai_url)
    assert probe.is_hf_downloadable is None
    assert probe.file_size == 2048
    auth_check.assert_not_called()
@pytest.mark.asyncio
async def test_is_gated_cached_across_calls(fresh_auth_store):
    """The intrinsic ``is_gated`` answer is determined once per URL.

    Repeated ``probe_url`` calls for the same URL must not re-issue the
    null-token auth_check — avoiding that is the point of the cache.
    """
    from app.model_downloader.gated_detection import probe_url

    public_url = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
    with patch(
        "app.model_downloader.gated_detection._auth_check_sync"
    ) as auth_check:
        with patch(
            "app.model_downloader.gated_detection._probe_size_once",
            new=AsyncMock(return_value=1024),
        ):
            for _attempt in range(3):
                await probe_url(public_url)
    # Three probes of a public URL still cost a single auth_check.
    assert auth_check.call_count == 1
@pytest.mark.asyncio
async def test_file_size_cached_across_calls(fresh_auth_store):
    """After one successful size probe, later calls reuse the cached size."""
    from app.model_downloader.gated_detection import probe_url

    public_url = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
    with patch("app.model_downloader.gated_detection._auth_check_sync"):
        with patch(
            "app.model_downloader.gated_detection._probe_size_once",
            new=AsyncMock(return_value=2048),
        ) as size_probe:
            first = await probe_url(public_url)
            second = await probe_url(public_url)
    for probe in (first, second):
        assert probe.file_size == 2048
    assert size_probe.call_count == 1
@pytest.mark.asyncio
async def test_file_size_not_probed_for_gated_no_access(fresh_auth_store):
    """No size probe is issued when ``is_hf_downloadable`` is False.

    Otherwise a 401-due-to-gating would land as a cached ``None`` size
    that survives a later successful login.
    """
    from app.model_downloader.gated_detection import probe_url
    from huggingface_hub.errors import GatedRepoError

    gated_url = "https://huggingface.co/gated/repo/resolve/main/x.safetensors"
    forbidden = MagicMock(status_code=403)
    gated_error = GatedRepoError("gated", response=forbidden)
    with patch(
        "app.model_downloader.gated_detection._auth_check_sync",
        side_effect=gated_error,
    ):
        with patch(
            "app.model_downloader.gated_detection._probe_size_once",
            new=AsyncMock(return_value=None),
        ) as size_probe:
            probe = await probe_url(gated_url)
    assert probe.is_hf_downloadable is False
    assert probe.file_size is None
    assert size_probe.call_count == 0
@pytest.mark.asyncio
async def test_probe_url_passes_token_when_available(fresh_auth_store, patched_user_dir):
    """A gated URL triggers two auth_checks: token-less, then with the token.

    The ``token=None`` call establishes the intrinsic ``is_gated`` flag;
    the call carrying the stored access token decides ``is_hf_downloadable``
    for the current user.
    """
    from app.model_downloader import gated_detection
    from app.model_downloader.gated_detection import probe_url
    from huggingface_hub.errors import GatedRepoError

    gated_detection.clear_caches_for_tests()
    stored = Token(
        access_token="hf_test_token",
        refresh_token=None,
        expires_at=9999999999.0,
    )
    fresh_auth_store.set_token(stored)
    private_url = "https://huggingface.co/private/repo/resolve/main/x.safetensors"
    forbidden = MagicMock(status_code=403)

    def _gated_without_token(repo_id, token):
        # Anonymous call → gated. With the real token the check passes
        # (returns None), i.e. the user has access.
        if token is None:
            raise GatedRepoError("gated", response=forbidden)

    with patch(
        "app.model_downloader.gated_detection._auth_check_sync",
        side_effect=_gated_without_token,
    ) as auth_check:
        with patch(
            "app.model_downloader.gated_detection._probe_size_once",
            new=AsyncMock(return_value=None),
        ):
            probe = await probe_url(private_url)
    # The token-authenticated call succeeded, so the file is downloadable.
    assert probe.is_hf_downloadable is True
    # Exactly two calls: anonymous first, then with the stored token.
    assert auth_check.call_count == 2
    anonymous_call, authed_call = auth_check.call_args_list[0], auth_check.call_args_list[1]
    assert anonymous_call.args == ("private/repo", None)
    assert authed_call.args == ("private/repo", "hf_test_token")
# --------------------------------------------------------------------------- #
# OAuth primitives
# --------------------------------------------------------------------------- #
def test_make_pkce_returns_distinct_high_entropy_values():
    """Two ``_make_pkce`` calls differ in verifier, challenge and state."""
    verifier_a, challenge_a, state_a = oauth._make_pkce()
    verifier_b, challenge_b, state_b = oauth._make_pkce()
    pairs = (
        (verifier_a, verifier_b),
        (challenge_a, challenge_b),
        (state_a, state_b),
    )
    for left, right in pairs:
        assert left != right
    # Verifier should be at least 43 chars per PKCE spec.
    assert len(verifier_a) >= 43
def test_build_authorize_url_includes_pkce_and_state():
    """The authorize URL starts at ``AUTHORIZE_URL`` and carries the OAuth params."""
    authorize_url = oauth._build_authorize_url("challenge123", "state456")
    assert authorize_url.startswith(oauth.AUTHORIZE_URL)
    required_fragments = (
        "client_id=" + oauth.HF_CLIENT_ID,
        "code_challenge=challenge123",
        "code_challenge_method=S256",
        "state=state456",
        "response_type=code",
    )
    for fragment in required_fragments:
        assert fragment in authorize_url
# --------------------------------------------------------------------------- #
# Routes
# --------------------------------------------------------------------------- #
@pytest.mark.asyncio
async def test_hf_auth_token_status_empty(aiohttp_client, app):
    """No token set → ``token_available`` is false and ``username`` is null."""
    client = await aiohttp_client(app)
    response = await client.get("/api/hf-auth-token-status")
    assert response.status == 200
    payload = await response.json()
    expected_payload = {"token_available": False, "username": None}
    assert payload == expected_payload
@pytest.mark.asyncio
async def test_hf_auth_token_status_with_token(
    aiohttp_client, app, fresh_auth_store, patched_user_dir
):
    """Token present and whoami resolving → the username is returned."""
    stored = Token(
        access_token="x", refresh_token=None, expires_at=9999999999.0,
    )
    fresh_auth_store.set_token(stored)
    whoami_patch = patch(
        "app.model_downloader.api.routes._whoami_username",
        return_value="alice",
    )
    with whoami_patch:
        client = await aiohttp_client(app)
        response = await client.get("/api/hf-auth-token-status")
        assert response.status == 200
        payload = await response.json()
        assert payload == {"token_available": True, "username": "alice"}
@pytest.mark.asyncio
async def test_hf_auth_login_start_403_when_ineligible(aiohttp_client, app, monkeypatch):
    """Eligibility check fails (not loopback / multi-user) → 403."""

    def _never_eligible():
        return False

    monkeypatch.setattr(
        "app.model_downloader.api.routes.is_hf_auth_eligible",
        _never_eligible,
    )
    client = await aiohttp_client(app)
    response = await client.post("/api/hf-auth-login-start")
    assert response.status == 403
    error = (await response.json())["error"]
    assert error["code"] == "HF_AUTH_NOT_ELIGIBLE"
@pytest.mark.asyncio
async def test_hf_auth_login_start_returns_authorize_url(aiohttp_client, app, monkeypatch):
    """Eligible and the flow starts → 200 carrying ``authorize_url``."""

    def _always_eligible():
        return True

    login_stub = AsyncMock(
        return_value="https://huggingface.co/oauth/authorize?fake=1"
    )
    monkeypatch.setattr(
        "app.model_downloader.api.routes.is_hf_auth_eligible",
        _always_eligible,
    )
    monkeypatch.setattr(
        "app.model_downloader.api.routes.start_login_flow",
        login_stub,
    )
    client = await aiohttp_client(app)
    response = await client.post("/api/hf-auth-login-start")
    assert response.status == 200
    payload = await response.json()
    assert payload["authorize_url"].startswith(
        "https://huggingface.co/oauth/authorize"
    )
@pytest.mark.asyncio
async def test_hf_auth_login_start_409_when_in_progress(aiohttp_client, app, monkeypatch):
    """``start_login_flow`` raising ``OAuthInProgressError`` → 409."""
    from app.model_downloader.hf_auth.oauth import OAuthInProgressError

    def _always_eligible():
        return True

    busy_flow = AsyncMock(side_effect=OAuthInProgressError())
    monkeypatch.setattr(
        "app.model_downloader.api.routes.is_hf_auth_eligible",
        _always_eligible,
    )
    monkeypatch.setattr(
        "app.model_downloader.api.routes.start_login_flow",
        busy_flow,
    )
    client = await aiohttp_client(app)
    response = await client.post("/api/hf-auth-login-start")
    assert response.status == 409
    error = (await response.json())["error"]
    assert error["code"] == "HF_AUTH_IN_PROGRESS"
@pytest.mark.asyncio
async def test_hf_auth_login_start_503_when_callback_bind_fails(
    aiohttp_client, app, monkeypatch
):
    """``OAuthCallbackError`` (e.g. callback port busy) → 503, not a dead URL."""
    from app.model_downloader.hf_auth.oauth import OAuthCallbackError

    def _always_eligible():
        return True

    failing_flow = AsyncMock(
        side_effect=OAuthCallbackError("could not bind callback port")
    )
    monkeypatch.setattr(
        "app.model_downloader.api.routes.is_hf_auth_eligible",
        _always_eligible,
    )
    monkeypatch.setattr(
        "app.model_downloader.api.routes.start_login_flow",
        failing_flow,
    )
    client = await aiohttp_client(app)
    response = await client.post("/api/hf-auth-login-start")
    assert response.status == 503
    error = (await response.json())["error"]
    assert error["code"] == "HF_AUTH_START_FAILED"
@pytest.mark.asyncio
async def test_hf_auth_logout_clears_store(
    aiohttp_client, app, fresh_auth_store, patched_user_dir
):
    """Logout answers ``{"logged_out": true}`` and leaves the store without a token."""
    stored = Token(
        access_token="x", refresh_token=None, expires_at=9999999999.0,
    )
    fresh_auth_store.set_token(stored)
    client = await aiohttp_client(app)
    response = await client.post("/api/hf-auth-logout")
    assert response.status == 200
    payload = await response.json()
    assert payload == {"logged_out": True}
    assert not fresh_auth_store.has_token()
@pytest.mark.asyncio
async def test_availability_includes_hf_auth_snapshot(aiohttp_client, app, monkeypatch):
    """The availability response embeds ``{token_available, eligible}``."""

    def _always_eligible():
        return True

    monkeypatch.setattr(
        "app.model_downloader.api.routes.is_hf_auth_eligible",
        _always_eligible,
    )
    client = await aiohttp_client(app)
    empty_request = {"models": {}}
    response = await client.post(
        "/api/models-availability-status",
        json=empty_request,
    )
    assert response.status == 200
    payload = await response.json()
    assert "hf_auth" in payload
    expected_snapshot = {"token_available": False, "eligible": True}
    assert payload["hf_auth"] == expected_snapshot

View File

@ -0,0 +1,514 @@
"""Unit tests for the server-side model download subsystem.
Covers the pieces that don't require talking to a real network:
- path parsing & allowlist (pure functions)
- DownloadServer registry lifecycle (in-memory state)
- API routes via aiohttp_client + folder_paths/probe_url patches
Streaming downloads themselves are exercised indirectly — the route-level
tests stub out the network probe so we can verify the gating logic in
``download_models`` without making real HTTP calls.
"""
from __future__ import annotations
import asyncio
from unittest.mock import patch, AsyncMock
import pytest
from aiohttp import web
from app.model_downloader.allowlist import is_url_allowed
from app.model_downloader.api.routes import register_routes
from app.model_downloader.download_server import DownloadServer
from app.model_downloader.gated_detection import MetadataProbeResult
from app.model_downloader.paths import (
InvalidModelId,
parse_model_id,
resolve_destination,
resolve_existing,
)
# --------------------------------------------------------------------------- #
# Fixtures
# --------------------------------------------------------------------------- #
@pytest.fixture
def model_root(tmp_path):
    """Create a fake ``models/`` root holding ``loras`` and ``checkpoints``.

    Returns ``(tmp_path, loras_dir, checkpoints_dir)``.
    """
    folder_dirs = [tmp_path / name for name in ("loras", "checkpoints")]
    for folder_dir in folder_dirs:
        folder_dir.mkdir()
    loras_dir, checkpoints_dir = folder_dirs
    return tmp_path, loras_dir, checkpoints_dir
@pytest.fixture
def patched_folder_paths(model_root):
    """Point ``folder_paths`` at the fake roots for the duration of one test.

    Replaces ``folder_paths.folder_names_and_paths`` with a two-entry
    mapping and routes ``folder_paths.get_folder_paths`` through it
    (unknown names resolve to an empty list). Yields the mapping.
    """
    _root, loras_dir, checkpoints_dir = model_root
    registered_dirs = (("loras", loras_dir), ("checkpoints", checkpoints_dir))
    mapping = {
        folder_name: ([str(folder_dir)], {".safetensors"})
        for folder_name, folder_dir in registered_dirs
    }

    def _roots_for(name):
        registered = mapping.get(name, ([], set()))
        return registered[0]

    with patch("folder_paths.folder_names_and_paths", mapping):
        with patch("folder_paths.get_folder_paths", side_effect=_roots_for):
            yield mapping
@pytest.fixture
def fresh_download_server():
    """Yield the module-level ``DOWNLOAD_SERVER``, reset before and after.

    Keeps registry state from leaking between tests that share the
    singleton.
    """
    from app.model_downloader.download_server import DOWNLOAD_SERVER as shared_server

    shared_server.reset_for_tests()
    yield shared_server
    shared_server.reset_for_tests()
@pytest.fixture
def app(patched_folder_paths, fresh_download_server):
    """Build an aiohttp application with the model-downloader routes registered.

    Requests ``patched_folder_paths`` and ``fresh_download_server`` so both
    are in place before any route runs.
    """
    application = web.Application()
    register_routes(application)
    return application
# --------------------------------------------------------------------------- #
# Pure helpers: allowlist + path parsing
# --------------------------------------------------------------------------- #
def test_allowlist_accepts_hf_safetensors():
    """An https HuggingFace ``.safetensors`` URL is allowed."""
    hf_model_url = "https://huggingface.co/x/y/resolve/main/z.safetensors"
    assert is_url_allowed(hf_model_url)
def test_allowlist_accepts_civitai_pth():
    """An https Civitai ``.pth`` URL is allowed."""
    civitai_model_url = "https://civitai.com/api/download/models/123.pth"
    assert is_url_allowed(civitai_model_url)
def test_allowlist_rejects_unknown_host():
    """A model URL on a host outside the allowlist is refused."""
    foreign_url = "https://example.com/x.safetensors"
    assert not is_url_allowed(foreign_url)
def test_allowlist_rejects_api_path_on_hf():
    """An allowlisted host alone isn't enough: the path must be a model file."""
    # Right host, but the path has no model extension.
    api_url = "https://huggingface.co/api/models"
    assert not is_url_allowed(api_url)
def test_allowlist_rejects_non_https_except_localhost():
    """Plain http is refused for HuggingFace but accepted for localhost."""
    insecure_hf_url = "http://huggingface.co/x/y.safetensors"
    local_url = "http://localhost:8000/x.safetensors"
    assert not is_url_allowed(insecure_hf_url)
    assert is_url_allowed(local_url)
def test_parse_model_id_valid(patched_folder_paths):
    """``folder/filename`` is split into its two components."""
    parsed = parse_model_id("loras/foo.safetensors")
    assert parsed == ("loras", "foo.safetensors")
def test_parse_model_id_rejects_traversal(patched_folder_paths):
    """A ``..`` traversal id raises ``InvalidModelId``."""
    traversal_id = "../etc/passwd"
    with pytest.raises(InvalidModelId):
        parse_model_id(traversal_id)
def test_parse_model_id_rejects_unknown_folder(patched_folder_paths):
    """An id whose folder isn't registered raises ``InvalidModelId``."""
    unregistered_id = "nope/x.safetensors"
    with pytest.raises(InvalidModelId):
        parse_model_id(unregistered_id)
def test_parse_model_id_rejects_double_slash(patched_folder_paths):
    """An id containing two slashes (a nested sub-path) raises ``InvalidModelId``."""
    nested_id = "loras/sub/x.safetensors"
    with pytest.raises(InvalidModelId):
        parse_model_id(nested_id)
def test_resolve_existing_returns_path_when_present(model_root, patched_folder_paths):
    """A model file that exists on disk resolves to its full path."""
    _root, loras_dir, _ = model_root
    existing_file = loras_dir / "foo.safetensors"
    existing_file.write_bytes(b"x")
    resolved = resolve_existing("loras/foo.safetensors")
    assert resolved == str(existing_file)
def test_resolve_existing_returns_none_when_absent(patched_folder_paths):
    """A model id with no file on disk resolves to ``None``."""
    resolved = resolve_existing("loras/missing.safetensors")
    assert resolved is None
def test_resolve_destination_returns_tmp_pair(model_root, patched_folder_paths):
    """``resolve_destination`` returns the final path and its epoch-tagged temp path."""
    _root, loras_dir, _ = model_root
    expected_final = str(loras_dir / "foo.safetensors")
    final, tmp = resolve_destination("loras/foo.safetensors", epoch=7)
    assert final == expected_final
    # The temp name carries the session epoch (so cancel+retry can't collide
    # on it) plus the subsystem-specific suffix the startup sweep matches.
    expected_tmp = f"{final}.7.comfy-download.tmp"
    assert tmp == expected_tmp
# --------------------------------------------------------------------------- #
# DownloadServer registry: lifecycle, races, cancellation epoch semantics
# --------------------------------------------------------------------------- #
def test_register_is_exclusive():
    """A second registration for the same model id is refused (returns ``None``)."""
    server = DownloadServer()
    model_id = "loras/x.safetensors"
    first = server.try_register(model_id, "https://huggingface.co/a")
    second = server.try_register(model_id, "https://huggingface.co/b")
    assert first is not None
    assert second is None
    assert server.is_downloading(model_id)
def test_cancel_removes_session():
    """Cancelling a registered download returns True and ends the session."""
    server = DownloadServer()
    model_id = "loras/x.safetensors"
    server.try_register(model_id, "https://huggingface.co/a")
    cancelled = server.cancel(model_id)
    assert cancelled is True
    assert not server.is_downloading(model_id)
def test_cancel_returns_false_when_absent():
    """Cancelling an id that was never registered returns False."""
    server = DownloadServer()
    cancelled = server.cancel("loras/never.safetensors")
    assert cancelled is False
def test_finish_only_clears_matching_epoch():
    """A stale worker's ``finish`` must not evict a newer session.

    After cancel + re-register for the same id, ``finish`` called with the
    original session leaves the new one in place; only ``finish`` with the
    new session clears it.
    """
    server = DownloadServer()
    model_id = "loras/x.safetensors"
    stale_session = server.try_register(model_id, "u1")
    server.cancel(model_id)
    current_session = server.try_register(model_id, "u2")
    assert current_session is not None and current_session.epoch != stale_session.epoch
    # Late finish() from the old worker: no effect.
    server.finish(stale_session)
    assert server.is_downloading(model_id)
    server.finish(current_session)
    assert not server.is_downloading(model_id)
def test_is_active_follows_cancellation():
    """A session is active after registration and inactive once cancelled."""
    server = DownloadServer()
    model_id = "loras/x.safetensors"
    session = server.try_register(model_id, "u")
    assert server.is_active(session)
    server.cancel(model_id)
    assert not server.is_active(session)
def test_update_progress_tracks_fraction():
    """With a known total the snapshot reports bytes, total and the fraction."""
    server = DownloadServer()
    model_id = "loras/x.safetensors"
    session = server.try_register(model_id, "u")
    server.update_progress(session, 50, 100)
    entry = server.snapshot()[model_id]
    assert entry.bytes_downloaded == 50
    assert entry.total_bytes == 100
    assert entry.progress == 0.5
def test_update_progress_with_unknown_total_keeps_progress_none():
    """With ``total_bytes`` unknown the snapshot's ``progress`` is ``None``."""
    server = DownloadServer()
    model_id = "loras/x.safetensors"
    session = server.try_register(model_id, "u")
    server.update_progress(session, 50, None)
    entry = server.snapshot()[model_id]
    assert entry.progress is None
def test_cleanup_orphan_tmp_files(model_root):
    """``sweep_orphan_tmp_files`` removes our orphan temp files only.

    The sweep must not run when the server is constructed, must delete a
    leftover ``*.comfy-download.tmp`` file, and must leave an unrelated
    ``*.tmp`` file in the same directory untouched.
    """
    _root, loras_dir, _ = model_root
    orphan = loras_dir / "stale.safetensors.3.comfy-download.tmp"
    orphan.write_bytes(b"partial")
    unrelated = loras_dir / "someothertool.tmp"
    unrelated.write_bytes(b"not ours")
    mapping = {"loras": ([str(loras_dir)], {".safetensors"})}

    def _roots_for(name):
        registered = mapping.get(name, ([], set()))
        return registered[0]

    with patch("folder_paths.folder_names_and_paths", mapping):
        with patch("folder_paths.get_folder_paths", side_effect=_roots_for):
            server = DownloadServer()
            assert orphan.exists(), "sweep must not run at construction time"
            server.sweep_orphan_tmp_files()
            assert not orphan.exists()
            assert unrelated.exists(), "unrelated .tmp must not be swept"
            # A repeat call with nothing left to sweep must not raise.
            server.sweep_orphan_tmp_files()
# --------------------------------------------------------------------------- #
# Route: POST /api/models-availability-status
# --------------------------------------------------------------------------- #
@pytest.mark.asyncio
async def test_availability_partitions_correctly(
    aiohttp_client, app, model_root, fresh_download_server
):
    """Models are classified as available, missing or downloading.

    One model exists on disk, one is registered as in flight, and one is
    neither; the response must report each state and include ``hf_auth``.
    """
    _root, loras_dir, _ = model_root
    (loras_dir / "present.safetensors").write_bytes(b"x")
    fresh_download_server.try_register(
        "loras/inflight.safetensors", "http://localhost:8000/x.safetensors"
    )
    client = await aiohttp_client(app)
    requested = {
        "models": {
            "loras/present.safetensors": "http://localhost:8000/p.safetensors",
            "loras/missing.safetensors": "http://localhost:8000/m.safetensors",
            "loras/inflight.safetensors": "http://localhost:8000/x.safetensors",
        }
    }
    # Probes are stubbed: this test is about state assignment, not network.
    blank_probe = MetadataProbeResult(
        file_size=None, is_hf_downloadable=None,
    )
    with patch(
        "app.model_downloader.api.routes.probe_url",
        new=AsyncMock(return_value=blank_probe),
    ):
        response = await client.post(
            "/api/models-availability-status", json=requested
        )
        assert response.status == 200
        payload = await response.json()
    reported = payload["models"]
    expected_states = {
        "loras/present.safetensors": "available",
        "loras/missing.safetensors": "missing",
        "loras/inflight.safetensors": "downloading",
    }
    for model_id, state in expected_states.items():
        assert reported[model_id]["state"] == state
    assert "hf_auth" in payload
@pytest.mark.asyncio
async def test_availability_invalid_id_classified_as_missing(aiohttp_client, app):
    """An invalid model id is reported as ``missing`` with a 200 response."""
    client = await aiohttp_client(app)
    bad_id = "../etc/passwd"
    blank_probe = MetadataProbeResult(
        file_size=None, is_hf_downloadable=None,
    )
    with patch(
        "app.model_downloader.api.routes.probe_url",
        new=AsyncMock(return_value=blank_probe),
    ):
        response = await client.post(
            "/api/models-availability-status",
            json={"models": {bad_id: "http://localhost:8000/x.safetensors"}},
        )
        assert response.status == 200
        payload = await response.json()
    assert payload["models"][bad_id]["state"] == "missing"
# --------------------------------------------------------------------------- #
# Route: POST /api/download-models — precondition gating
# --------------------------------------------------------------------------- #
@pytest.mark.asyncio
async def test_download_rejects_url_not_in_allowlist(aiohttp_client, app):
    """A download URL on a non-allowlisted host is rejected with 400."""
    client = await aiohttp_client(app)
    request_body = {
        "models": {"loras/x.safetensors": "https://evil.com/x.safetensors"}
    }
    response = await client.post("/api/download-models", json=request_body)
    assert response.status == 400
    error = (await response.json())["error"]
    assert error["code"] == "URL_NOT_ALLOWED"
@pytest.mark.asyncio
async def test_download_rejects_already_available(
    aiohttp_client, app, model_root
):
    """Requesting a model whose file already exists on disk yields 409."""
    _root, loras_dir, _ = model_root
    (loras_dir / "x.safetensors").write_bytes(b"x")
    client = await aiohttp_client(app)
    request_body = {
        "models": {
            "loras/x.safetensors": "https://huggingface.co/a/b/resolve/main/x.safetensors"
        }
    }
    response = await client.post("/api/download-models", json=request_body)
    assert response.status == 409
    error = (await response.json())["error"]
    assert error["code"] == "ALREADY_AVAILABLE"
@pytest.mark.asyncio
async def test_download_rejects_already_downloading(
    aiohttp_client, app, fresh_download_server
):
    """Requesting a model that already has a registered session yields 409."""
    fresh_download_server.try_register(
        "loras/x.safetensors", "https://huggingface.co/u.safetensors"
    )
    client = await aiohttp_client(app)
    request_body = {
        "models": {
            "loras/x.safetensors": "https://huggingface.co/a/b/resolve/main/x.safetensors"
        }
    }
    response = await client.post("/api/download-models", json=request_body)
    assert response.status == 409
    error = (await response.json())["error"]
    assert error["code"] == "ALREADY_DOWNLOADING"
@pytest.mark.asyncio
async def test_download_rejects_gated_model(aiohttp_client, app):
    """A probe reporting ``is_hf_downloadable=False`` makes the request fail with 400."""
    client = await aiohttp_client(app)
    not_downloadable = MetadataProbeResult(file_size=None, is_hf_downloadable=False)
    request_body = {
        "models": {
            "loras/x.safetensors": "https://huggingface.co/g/r/resolve/main/x.safetensors"
        }
    }
    with patch(
        "app.model_downloader.api.routes.probe_url",
        new=AsyncMock(return_value=not_downloadable),
    ):
        response = await client.post("/api/download-models", json=request_body)
        assert response.status == 400
        error = (await response.json())["error"]
        assert error["code"] == "MODEL_NOT_DOWNLOADABLE"
@pytest.mark.asyncio
async def test_download_rejects_invalid_model_id(aiohttp_client, app):
    """A model id containing a parent-directory segment gets a 400 with
    error code ``INVALID_MODEL_ID``."""
    payload = {
        "models": {"../etc/passwd": "https://huggingface.co/x.safetensors"},
    }
    client = await aiohttp_client(app)
    resp = await client.post("/api/download-models", json=payload)
    assert resp.status == 400
    body = await resp.json()
    assert body["error"]["code"] == "INVALID_MODEL_ID"
@pytest.mark.asyncio
async def test_download_atomic_failure_does_not_register_partial(
    aiohttp_client, app, model_root, fresh_download_server
):
    """If one model in a batch fails, none get registered."""
    _, loras_dir, _ = model_root
    # One of the two requested models already exists on disk; the request
    # as a whole is expected to be rejected with 409.
    present_file = loras_dir / "already.safetensors"
    present_file.write_bytes(b"x")
    batch = {
        "loras/already.safetensors":
            "https://huggingface.co/a/b/resolve/main/already.safetensors",
        "loras/new.safetensors":
            "https://huggingface.co/a/b/resolve/main/new.safetensors",
    }
    client = await aiohttp_client(app)
    resp = await client.post("/api/download-models", json={"models": batch})
    assert resp.status == 409
    # The other model in the batch must not have been registered as part
    # of the rejected request.
    new_is_active = fresh_download_server.is_downloading("loras/new.safetensors")
    assert not new_is_active
@pytest.mark.asyncio
async def test_download_schedules_when_all_preconditions_pass(
    aiohttp_client, app, fresh_download_server
):
    """Verify the precondition pass, registration pass, and async
    scheduling all wire up correctly.

    The streamer is patched to avoid real HTTP while still letting the
    route execute end-to-end. The fake streamer parks on ``finish_signal``
    so the test can observe the in-flight state; the signal is always
    released in a ``finally`` so a failed assertion or a timeout never
    leaves the worker blocked forever.
    """
    started = asyncio.Event()
    finish_signal = asyncio.Event()

    async def fake_stream(session):
        # Tell the test the worker reached the streamer, then wait until
        # the test releases it.
        started.set()
        await finish_signal.wait()
        from app.model_downloader.download_server import DOWNLOAD_SERVER
        DOWNLOAD_SERVER.finish(session)
        return "/dev/null"

    with patch(
        "app.model_downloader.api.routes.probe_url",
        new=AsyncMock(return_value=MetadataProbeResult(file_size=42, is_hf_downloadable=True)),
    ), patch(
        "app.model_downloader.downloader.stream_to_disk", new=fake_stream
    ):
        try:
            client = await aiohttp_client(app)
            resp = await client.post(
                "/api/download-models",
                json={"models": {
                    "loras/new.safetensors":
                        "https://huggingface.co/a/b/resolve/main/new.safetensors"
                }},
            )
            assert resp.status == 202
            body = await resp.json()
            assert body["accepted"] is True
            assert body["scheduled"] == ["loras/new.safetensors"]
            # Wait for the worker to actually start.
            await asyncio.wait_for(started.wait(), timeout=2.0)
            assert fresh_download_server.is_downloading("loras/new.safetensors")
        finally:
            # Release the parked streamer on every exit path, not only on
            # success; setting the event is harmless if nothing is waiting.
            finish_signal.set()
# --------------------------------------------------------------------------- #
# Route: POST /api/cancel-model-download-session
# --------------------------------------------------------------------------- #
@pytest.mark.asyncio
async def test_cancel_removes_active_session(
    aiohttp_client, app, fresh_download_server
):
    """Cancelling a registered download returns 200 with ``cancelled``
    true, and the id is no longer reported as downloading afterwards."""
    model_id = "loras/x.safetensors"
    fresh_download_server.try_register(
        model_id, "https://huggingface.co/u.safetensors"
    )
    client = await aiohttp_client(app)
    resp = await client.post(
        "/api/cancel-model-download-session",
        json={"model_id": model_id},
    )
    assert resp.status == 200
    body = await resp.json()
    assert body["cancelled"] is True
    still_active = fresh_download_server.is_downloading(model_id)
    assert not still_active
@pytest.mark.asyncio
async def test_cancel_returns_404_when_none(aiohttp_client, app):
    """Cancelling an id with no registered download gets a 404 with error
    code ``NOT_DOWNLOADING``."""
    payload = {"model_id": "loras/nothing.safetensors"}
    client = await aiohttp_client(app)
    resp = await client.post("/api/cancel-model-download-session", json=payload)
    assert resp.status == 404
    body = await resp.json()
    assert body["error"]["code"] == "NOT_DOWNLOADING"
# --------------------------------------------------------------------------- #
# Unified availability response embeds metadata per id
# --------------------------------------------------------------------------- #
@pytest.mark.asyncio
async def test_availability_embeds_metadata(aiohttp_client, app):
    """``file_size`` + ``is_hf_downloadable`` come back on the same
    request as the state - no separate metadata endpoint."""
    free_url = "https://huggingface.co/a/b/resolve/main/free.safetensors"
    gated_url = "https://huggingface.co/g/r/resolve/main/gated.safetensors"
    # Canned probe answers keyed by URL; the stub below serves them
    # instead of making real HTTP requests.
    probe_results = {
        free_url: MetadataProbeResult(file_size=1024, is_hf_downloadable=True),
        gated_url: MetadataProbeResult(file_size=None, is_hf_downloadable=False),
    }

    async def fake_probe(url):
        return probe_results[url]

    requested = {
        "loras/free.safetensors": free_url,
        "loras/gated.safetensors": gated_url,
    }
    with patch("app.model_downloader.api.routes.probe_url", new=fake_probe):
        client = await aiohttp_client(app)
        resp = await client.post(
            "/api/models-availability-status",
            json={"models": requested},
        )
    assert resp.status == 200
    payload = await resp.json()
    models = payload["models"]
    free_entry = models["loras/free.safetensors"]
    assert free_entry["file_size"] == 1024
    assert free_entry["is_hf_downloadable"] is True
    gated_entry = models["loras/gated.safetensors"]
    assert gated_entry["file_size"] is None
    assert gated_entry["is_hf_downloadable"] is False