This commit is contained in:
DoronGenzelHass 2026-06-22 15:30:27 +03:00 committed by GitHub
commit 58c5cef538
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 4521 additions and 0 deletions

View File

@ -0,0 +1,46 @@
"""URL allowlist for server-side model fetches.
Mirrors the frontend's ``isModelDownloadable`` allowlist so the two flows
agree on which URLs are eligible for download. Server-side allowlisting is
the primary SSRF defense for this subsystem — workflow JSON is untrusted
input (anyone can hand-craft one), so we never let the server fetch URLs
outside this list.
"""
from urllib.parse import urlparse
# Frontend parity: ``missingModelDownload-*.js`` exports the same triple as
# ``i = [...]`` (Civitai / HuggingFace / localhost).
_ALLOWED_URL_PREFIXES = (
    "https://huggingface.co/",
    "https://civitai.com/",
    "http://localhost:",
    "http://127.0.0.1:",
)

# Frontend parity: same set as ``a = [...]`` in the bundle.
_ALLOWED_MODEL_EXTENSIONS = (
    ".safetensors",
    ".sft",
    ".ckpt",
    ".pth",
    ".pt",
)


def is_url_allowed(url: str) -> bool:
    """Check whether ``url`` is permitted as a server-side download source.

    Returns True only when all of the following hold:

    - the URL starts with one of the allowed prefixes,
    - the URL parses cleanly and its authority carries no ``user:pass@``
      part, AND
    - the URL's path ends with a known model extension.

    The extension check keeps arbitrary HTML / API endpoints on allowlisted
    hosts (e.g. ``https://huggingface.co/api/...``) off the table.

    Args:
        url: Candidate URL taken from untrusted workflow JSON.

    Returns:
        True if the server may fetch ``url``; False otherwise. Never raises:
        non-string, empty, and malformed inputs all yield False.
    """
    if not isinstance(url, str) or not url:
        return False
    if not url.startswith(_ALLOWED_URL_PREFIXES):
        return False
    try:
        parts = urlparse(url)
        # Reading ``port`` forces validation; a non-numeric or out-of-range
        # port raises ValueError just like a malformed authority does.
        _ = parts.port
    except ValueError:
        return False
    # The localhost prefixes end in ":" rather than "/", so a plain prefix
    # match would accept ``http://localhost:pw@evil.example/x.pt``, where
    # "localhost:pw" is userinfo and the real host is evil.example. None of
    # the allowlisted sources needs userinfo, so reject it outright.
    if "@" in parts.netloc:
        return False
    return parts.path.endswith(_ALLOWED_MODEL_EXTENSIONS)

View File

@ -0,0 +1,332 @@
"""Aiohttp routes for the server-side model download subsystem.
Endpoint surface (all under ``/api/``, all kebab-case):
- ``POST /api/models-availability-status`` — bulk status + metadata query.
- ``POST /api/download-models`` — start a batch of downloads.
- ``POST /api/cancel-model-download-session`` — cancel a single in-flight one.
- ``GET /api/hf-auth-token-status`` — current HF login state.
- ``POST /api/hf-auth-login-start`` — begin the HF OAuth flow.
- ``POST /api/hf-auth-logout`` — drop the stored HF token.
The contract is intentionally narrow: only model_ids of the form
``<directory>/<filename>`` (validated via ``app.model_downloader.paths``)
are accepted, and only URLs on the same allowlist the frontend already
uses (HuggingFace, Civitai, localhost) can be fetched. Both are required
to keep the server out of the SSRF business for this feature.
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any, Optional
from aiohttp import web
from pydantic import BaseModel, ValidationError
from app.model_downloader.allowlist import is_url_allowed
from app.model_downloader.download_server import (
DOWNLOAD_SERVER,
DownloadSession,
)
from app.model_downloader.downloader import schedule_batch
from app.model_downloader.gated_detection import probe_url
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
from app.model_downloader.hf_auth.eligibility import is_hf_auth_eligible
from app.model_downloader.hf_auth.oauth import (
OAuthInProgressError,
start_login_flow,
)
from app.model_downloader.paths import (
InvalidModelId,
parse_model_id,
resolve_existing,
)
from app.model_downloader.api import schemas_in, schemas_out
ROUTES = web.RouteTableDef()
def register_routes(app: web.Application) -> None:
    """Attach every handler collected in ``ROUTES`` to ``app``.

    Intended to be invoked once from ``server.py`` while ``PromptServer``
    starts up.
    """
    route_table = ROUTES
    app.add_routes(route_table)
# ----- response helpers (same envelope as app/assets/api/routes.py) -----
def _error(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
    """Build a JSON error response in the shared ``{"error": {...}}`` envelope.

    Args:
        status: HTTP status code for the response.
        code: Machine-readable error code.
        message: Human-readable description.
        details: Optional extra context; a falsy value is sent as ``{}``.
    """
    envelope = {
        "error": {
            "code": code,
            "message": message,
            "details": details or {},
        }
    }
    return web.json_response(envelope, status=status)
def _validation_error(code: str, ve: ValidationError) -> web.Response:
    """Turn a pydantic ``ValidationError`` into a 400 error response.

    The error list is round-tripped through ``ve.json()`` so the ``errors``
    detail holds plain JSON-compatible data.
    """
    error_list = json.loads(ve.json())
    return _error(400, code, "Validation failed.", {"errors": error_list})
def _ok(payload: BaseModel, status: int = 200) -> web.Response:
    """Serialize a pydantic response model as a JSON success response.

    ``None`` fields are kept in the output (``exclude_none=False``).
    """
    body = payload.model_dump(mode="json", exclude_none=False)
    return web.json_response(body, status=status)
async def _parse_body(request: web.Request, model: type[BaseModel]) -> Any:
    """Parse the JSON request body into ``model``.

    This helper *returns* (it does not raise) the error response, so every
    caller must check ``isinstance(result, web.Response)`` before using the
    value as a validated model.

    Args:
        request: Incoming aiohttp request.
        model: Pydantic model class describing the expected body.

    Returns:
        A validated ``model`` instance on success; otherwise a 400
        ``web.Response`` with code ``INVALID_JSON`` (body is not decodable
        JSON) or ``INVALID_BODY`` (JSON does not match the schema).
    """
    try:
        raw = await request.json()
    except ValueError:
        # ValueError covers json.JSONDecodeError and also UnicodeDecodeError,
        # which is raised when the body bytes don't decode as text. Both are
        # client errors, not server faults.
        return _error(400, "INVALID_JSON", "Request body must be valid JSON.")
    try:
        return model.model_validate(raw)
    except ValidationError as ve:
        return _validation_error("INVALID_BODY", ve)
# ----- 1. availability status (unified: state + metadata per id) -----
async def _probe_if_allowed(url: str) -> Any:
    """Probe ``url`` only when it passes the server-side allowlist.

    Returns the probe result, or ``None`` when the URL is not allowlisted.
    Probing can make outbound requests, so an un-allowlisted URL from
    untrusted workflow JSON must never reach ``probe_url``.
    """
    if not is_url_allowed(url):
        return None
    return await probe_url(url)


@ROUTES.post("/api/models-availability-status")
async def models_availability_status(request: web.Request) -> web.Response:
    """Return per-id ``{state, progress, file_size, is_hf_downloadable}``.

    State (``available`` / ``missing`` / ``downloading``) is recomputed on
    every call. ``file_size`` and ``is_hf_downloadable`` come from the URL
    probe; ``is_hf_downloadable`` reflects the current token state, which
    is what makes login + license acceptance show up in the UI within one
    poll cycle without any frontend cache plumbing.

    URLs that are not on the download allowlist are never probed; their
    entries carry ``file_size=None`` and ``is_hf_downloadable=None``.

    Returns:
        200 with an ``AvailabilityStatusResponse``, or the 400 produced by
        ``_parse_body`` for a malformed request.
    """
    parsed = await _parse_body(request, schemas_in.AvailabilityStatusRequest)
    if isinstance(parsed, web.Response):
        return parsed

    items = list(parsed.models.items())

    # Run all probes concurrently; each is internally cached per URL.
    probes = await asyncio.gather(*(_probe_if_allowed(url) for _, url in items))

    response_models: dict[str, schemas_out.ModelStatusEntry] = {}
    for (model_id, _url), probe in zip(items, probes):
        # ``probe`` is None when the URL was rejected by the allowlist.
        file_size = probe.file_size if probe is not None else None
        is_hf_downloadable = probe.is_hf_downloadable if probe is not None else None

        try:
            parse_model_id(model_id)
        except InvalidModelId:
            # Ill-formed identifier: report as missing without 400-ing the
            # whole batch — the workflow author probably typo'd.
            response_models[model_id] = schemas_out.ModelStatusEntry(
                state="missing",
                file_size=file_size,
                is_hf_downloadable=is_hf_downloadable,
            )
            continue

        active = DOWNLOAD_SERVER.get(model_id)
        if active is not None:
            response_models[model_id] = schemas_out.ModelStatusEntry(
                state="downloading",
                progress=schemas_out.DownloadProgress(
                    bytes_downloaded=active.bytes_downloaded,
                    total_bytes=active.total_bytes,
                    progress=active.progress,
                ),
                file_size=file_size,
                is_hf_downloadable=is_hf_downloadable,
            )
            continue

        state: schemas_out.ModelState = (
            "available" if resolve_existing(model_id) is not None else "missing"
        )
        response_models[model_id] = schemas_out.ModelStatusEntry(
            state=state,
            file_size=file_size,
            is_hf_downloadable=is_hf_downloadable,
        )

    return _ok(schemas_out.AvailabilityStatusResponse(
        models=response_models,
        hf_auth=schemas_out.HfAuthStatus(
            token_available=HF_AUTH_STORE.has_token(),
            eligible=is_hf_auth_eligible(),
        ),
    ))
# ----- 2. start downloads -----
@ROUTES.post("/api/download-models")
async def download_models(request: web.Request) -> web.Response:
    """Validate a batch of ``{model_id: url}`` pairs and schedule their downloads.

    The request is all-or-nothing: every model is checked first, and the
    first failing precondition rejects the whole batch before any session
    is registered.

    Returns:
        - 202 with the scheduled model ids once every session is registered.
        - 400 ``EMPTY_REQUEST``, ``INVALID_MODEL_ID``, ``URL_NOT_ALLOWED`` or
          ``MODEL_NOT_DOWNLOADABLE`` for request-level problems.
        - 409 ``ALREADY_AVAILABLE`` or ``ALREADY_DOWNLOADING`` on conflicts
          with on-disk files or in-flight sessions.
    """
    parsed = await _parse_body(request, schemas_in.DownloadModelsRequest)
    if isinstance(parsed, web.Response):
        return parsed

    if not parsed.models:
        return _error(400, "EMPTY_REQUEST", "No models supplied.")

    # ----- precondition pass: validate everything BEFORE registering anything -----
    # Atomic semantics: if any model fails any precondition (invalid id,
    # not allow-listed URL, already on disk, already downloading, or gated),
    # the entire request fails and no state is changed.
    requested = list(parsed.models.items())
    for model_id, url in requested:
        try:
            parse_model_id(model_id)
        except InvalidModelId as e:
            return _error(400, "INVALID_MODEL_ID", str(e),
                          {"model_id": model_id})
        if not is_url_allowed(url):
            return _error(
                400, "URL_NOT_ALLOWED",
                "Server-side downloads only accept HuggingFace, Civitai, "
                "or localhost URLs ending in a known model extension.",
                {"model_id": model_id, "url": url},
            )
        if resolve_existing(model_id) is not None:
            return _error(409, "ALREADY_AVAILABLE",
                          f"Model already exists on disk: {model_id}",
                          {"model_id": model_id})
        if DOWNLOAD_SERVER.is_downloading(model_id):
            return _error(409, "ALREADY_DOWNLOADING",
                          f"A download for {model_id} is already in progress.",
                          {"model_id": model_id})

    # Reachability check last — it's the only one that talks to the
    # network. Concurrent probes. For HF URLs ``is_hf_downloadable``
    # reflects current token access; for non-HF URLs it's None, and we
    # treat that as "no info, proceed".
    probes = await asyncio.gather(*(probe_url(url) for _, url in requested))
    for (model_id, url), probe in zip(requested, probes):
        if probe.is_hf_downloadable is False:
            return _error(
                400, "MODEL_NOT_DOWNLOADABLE",
                f"Model {model_id} is gated on HuggingFace and the current "
                f"server token (if any) does not grant access.",
                {"model_id": model_id, "url": url},
            )

    # ----- registration pass: try_register is atomic per model_id -----
    # Defensive: another request might have raced past our pre-check
    # between the loop above and here (the probe ``await`` yields to other
    # handlers). try_register handles that.
    sessions: list[DownloadSession] = []
    for model_id, url in requested:
        session = DOWNLOAD_SERVER.try_register(model_id, url)
        if session is None:
            # Race: someone else got in. Roll back what we registered.
            for s in sessions:
                DOWNLOAD_SERVER.cancel(s.model_id)
            return _error(409, "ALREADY_DOWNLOADING",
                          f"A download for {model_id} is already in progress (race).",
                          {"model_id": model_id})
        sessions.append(session)

    # Only the first call actually sweeps; it runs before any worker for
    # this batch has been scheduled.
    DOWNLOAD_SERVER.sweep_orphan_tmp_files()
    schedule_batch(sessions)
    logging.info(
        "[model_downloader] scheduled %d downloads: %s",
        len(sessions), [s.model_id for s in sessions],
    )
    return _ok(schemas_out.DownloadModelsResponse(
        accepted=True,
        scheduled=[s.model_id for s in sessions],
    ), status=202)
# ----- 3. cancel a session -----
@ROUTES.post("/api/cancel-model-download-session")
async def cancel_model_download_session(request: web.Request) -> web.Response:
    """Cancel the in-flight download registered for the given ``model_id``.

    Responds 200 with ``cancelled=True`` when a session was removed, or
    404 ``NOT_DOWNLOADING`` when no session exists for that id.
    """
    body = await _parse_body(request, schemas_in.CancelDownloadSessionRequest)
    if isinstance(body, web.Response):
        return body

    if DOWNLOAD_SERVER.cancel(body.model_id):
        return _ok(schemas_out.CancelDownloadSessionResponse(cancelled=True))

    return _error(
        404,
        "NOT_DOWNLOADING",
        f"No active download for {body.model_id}.",
        {"model_id": body.model_id},
    )
# ----- 4. HuggingFace OAuth status / login start / logout -----
@ROUTES.get("/api/hf-auth-token-status")
async def hf_auth_token_status(request: web.Request) -> web.Response:
    """Report whether the server holds an HF token, plus its username.

    ``token_available`` mirrors ``HF_AUTH_STORE.has_token()``. ``username``
    is filled in only when a valid token can be obtained and the whoami
    lookup succeeds; any lookup failure is logged at debug level and
    leaves ``username`` as ``None``.
    """
    has_token = HF_AUTH_STORE.has_token()
    resolved_name: Optional[str] = None

    valid = await HF_AUTH_STORE.get_valid_token() if has_token else None
    if valid is not None:
        # The whoami helper is synchronous and performs a network call, so
        # it runs in a worker thread to keep the event loop responsive.
        try:
            resolved_name = await asyncio.to_thread(
                _whoami_username, valid.access_token
            )
        except Exception as e:
            logging.debug("[hf_auth] whoami failed: %s", e)

    status_payload = schemas_out.HfAuthTokenStatusResponse(
        token_available=has_token,
        username=resolved_name,
    )
    return _ok(status_payload)
def _whoami_username(token: str) -> Optional[str]:
    """Synchronously look up the HF account name for ``token``.

    Returns the ``name`` field of the whoami payload, falling back to
    ``fullname``; ``None`` when the payload is not a dict or has neither.
    """
    from huggingface_hub import HfApi

    payload = HfApi().whoami(token=token)
    if not isinstance(payload, dict):
        return None
    return payload.get("name") or payload.get("fullname")
@ROUTES.post("/api/hf-auth-login-start")
async def hf_auth_login_start(request: web.Request) -> web.Response:
    """Start one HuggingFace OAuth attempt and return its authorize URL.

    Responds 403 ``HF_AUTH_NOT_ELIGIBLE`` when the deployment is not
    eligible for interactive login, and 409 ``HF_AUTH_IN_PROGRESS`` when
    another attempt is already running.
    """
    if is_hf_auth_eligible():
        try:
            authorize_url = await start_login_flow()
        except OAuthInProgressError:
            return _error(
                409, "HF_AUTH_IN_PROGRESS",
                "Another HuggingFace login attempt is in progress. Try again "
                "after it completes or times out.",
            )
        return _ok(schemas_out.HfAuthLoginStartResponse(authorize_url=authorize_url))

    return _error(
        403, "HF_AUTH_NOT_ELIGIBLE",
        "This server is not eligible for interactive HuggingFace login. "
        "It must be bound to a loopback address and not running in "
        "--multi-user mode.",
    )
@ROUTES.post("/api/hf-auth-logout")
async def hf_auth_logout(request: web.Request) -> web.Response:
    """Clear the stored HF token and acknowledge with ``logged_out=True``."""
    HF_AUTH_STORE.clear()
    acknowledgement = schemas_out.HfAuthLogoutResponse(logged_out=True)
    return _ok(acknowledgement)

View File

@ -0,0 +1,41 @@
"""Request schemas for the model-downloader API.
Each endpoint accepts a small JSON body. Pydantic enforces the shape at
the boundary; route handlers operate only on validated values past that.
"""
from __future__ import annotations
from pydantic import BaseModel, Field
class AvailabilityStatusRequest(BaseModel):
    """``POST /api/models-availability-status``.

    Sent by the frontend on each poll. Each entry is ``{model_id: url}``;
    the URL is the one declared in ``properties.models[i].url`` in the
    workflow JSON and lets the server compute per-id metadata
    (``file_size`` + ``is_hf_downloadable``) on the same request.
    """

    # model_id -> source URL. An omitted field validates as an empty dict.
    models: dict[str, str] = Field(default_factory=dict)
class DownloadModelsRequest(BaseModel):
    """``POST /api/download-models``.

    Same shape as the availability request: the URL for each model_id.
    The endpoint returns immediately after validation and scheduling.
    """

    # model_id -> source URL. An omitted field validates as an empty dict.
    models: dict[str, str] = Field(default_factory=dict)
class CancelDownloadSessionRequest(BaseModel):
    """``POST /api/cancel-model-download-session``."""

    # Identifier of the download to cancel (required).
    model_id: str
# Names exported by ``from ... import *``: the three request models above.
__all__ = [
    "AvailabilityStatusRequest",
    "DownloadModelsRequest",
    "CancelDownloadSessionRequest",
]

View File

@ -0,0 +1,81 @@
"""Response schemas for the model-downloader API."""
from __future__ import annotations
from typing import Literal, Optional
from pydantic import BaseModel
# Closed set of per-model states reported by the availability endpoint.
ModelState = Literal["available", "missing", "downloading"]
class DownloadProgress(BaseModel):
    """Embedded in a model entry when its state is ``downloading``."""

    # Bytes written so far.
    bytes_downloaded: int
    # Expected total size in bytes; None while unknown.
    total_bytes: Optional[int] = None
    progress: Optional[float] = None  # fraction in [0,1]; null until total known
class ModelStatusEntry(BaseModel):
    """Everything the UI needs to render one row, in one shot.

    ``state`` reflects what the server has on disk + in-flight.
    ``file_size`` and ``is_hf_downloadable`` come from the per-URL probe.
    Both are included in every poll response, so license-acceptance flips
    show up within one poll interval without any frontend cache
    invalidation.
    """

    state: ModelState
    # Present only while ``state`` is "downloading".
    progress: Optional[DownloadProgress] = None
    # Size in bytes reported by the probe; None when unknown.
    file_size: Optional[int] = None
    # HF-only: True iff the server can fetch this URL with current auth
    # state. False iff gated and lacking access. None for non-HF URLs.
    is_hf_downloadable: Optional[bool] = None
class HfAuthStatus(BaseModel):
    """Snapshot of HF login state, embedded in availability response."""

    # Whether the server currently holds an HF token.
    token_available: bool
    # Whether this deployment may offer the interactive HF login flow.
    eligible: bool
class AvailabilityStatusResponse(BaseModel):
    """Response body for ``POST /api/models-availability-status``."""

    # model_id -> status entry, one per id in the request.
    models: dict[str, ModelStatusEntry]
    # HF login snapshot taken while building this response.
    hf_auth: HfAuthStatus
class DownloadModelsResponse(BaseModel):
    """Response body for ``POST /api/download-models`` (sent with HTTP 202)."""

    # True when the batch was accepted for download.
    accepted: bool
    # model_ids for which a download session was registered.
    scheduled: list[str]
class CancelDownloadSessionResponse(BaseModel):
    """Response body for ``POST /api/cancel-model-download-session``."""

    # True when an active session was removed.
    cancelled: bool
class HfAuthTokenStatusResponse(BaseModel):
    """Response body for ``GET /api/hf-auth-token-status``."""

    # Whether the server currently holds an HF token.
    token_available: bool
    # HF account name; None when no token is held or the lookup failed.
    username: Optional[str] = None
class HfAuthLoginStartResponse(BaseModel):
    """Response body for ``POST /api/hf-auth-login-start``."""

    # URL returned by the login flow for the user to open.
    authorize_url: str
class HfAuthLogoutResponse(BaseModel):
    """Response body for ``POST /api/hf-auth-logout``."""

    # True once the stored token has been cleared.
    logged_out: bool
# Names exported by ``from ... import *``: the state alias and every
# response model defined above.
__all__ = [
    "ModelState",
    "DownloadProgress",
    "ModelStatusEntry",
    "HfAuthStatus",
    "AvailabilityStatusResponse",
    "DownloadModelsResponse",
    "CancelDownloadSessionResponse",
    "HfAuthTokenStatusResponse",
    "HfAuthLoginStartResponse",
    "HfAuthLogoutResponse",
]

View File

@ -0,0 +1,179 @@
"""Process-wide registry of in-flight model downloads.
A single instance, ``DOWNLOAD_SERVER``, tracks every currently-running
server-side model fetch. Designed to be safe with multiple concurrent
clients hitting the API: each model_id has at most one active session,
and the API rejects requests that conflict with in-flight downloads.
Cancellation is cooperative — the download loop checks ``is_active`` on
its own session between chunks and raises ``DownloadCancelled`` when the
session has been removed from the registry. This avoids the complications
of ``Task.cancel()`` from outside the loop while still giving deterministic
rollback semantics (the worker is responsible for deleting its own
``.tmp`` on the cancel path).
"""
from __future__ import annotations
import logging
import os
import threading
from dataclasses import dataclass, field
from typing import Optional
from app.model_downloader.paths import iter_all_tmp_paths
class DownloadCancelled(Exception):
    """Signal that a download's session is no longer registered.

    Raised by the streaming loop when its session has been removed from
    the registry (cancellation request), so the worker should roll back
    its ``.tmp`` file.
    """
@dataclass
class DownloadSession:
    """State of a single in-flight download.

    ``progress`` is a fraction in ``[0.0, 1.0]`` and stays ``None`` until
    a positive ``total_bytes`` is known. ``total_bytes`` mirrors the
    response's ``Content-Length`` when one is present.
    """

    model_id: str
    url: str
    progress: Optional[float] = None
    bytes_downloaded: int = 0
    total_bytes: Optional[int] = None
    # Identity token for the cancellation check: the registry hands every
    # new session a fresh number, so a stale worker from before a
    # "cancel + restart" never matches the replacement session.
    epoch: int = 0
class DownloadServer:
    """Registry of active downloads, used as a process-wide singleton.

    Every read and write of the session map happens under ``_lock`` so
    concurrent route handlers see a consistent view. A plain threading
    lock is used because the registry is consulted both from route
    handlers and from the worker coroutines that perform downloads.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._sessions: dict[str, DownloadSession] = {}
        self._epoch_counter = 0
        self._orphan_sweep_done = False

    def _live_match(self, session: DownloadSession) -> Optional[DownloadSession]:
        """Return the registered session if it is ``session``'s epoch.

        Returns ``None`` when nothing is registered for the model_id or a
        different epoch holds the slot. The caller must hold ``_lock``.
        """
        registered = self._sessions.get(session.model_id)
        if registered is None or registered.epoch != session.epoch:
            return None
        return registered

    # ----- lifecycle -----

    def sweep_orphan_tmp_files(self) -> None:
        """Delete leftover ``*.tmp`` files; only the first call does work.

        Kept off the import path so module load doesn't block on
        filesystem I/O. The done-flag is flipped under the lock and the
        deletion itself happens outside it.
        """
        with self._lock:
            already_swept = self._orphan_sweep_done
            self._orphan_sweep_done = True
        if already_swept:
            return
        for path in iter_all_tmp_paths():
            try:
                os.remove(path)
                logging.info("[model_downloader] removed orphan tmp file: %s", path)
            except OSError as e:
                logging.warning("[model_downloader] could not remove %s: %s", path, e)

    # ----- queries -----

    def is_downloading(self, model_id: str) -> bool:
        """True iff a session is currently registered for ``model_id``."""
        with self._lock:
            return model_id in self._sessions

    def get(self, model_id: str) -> Optional[DownloadSession]:
        """Return the registered session for ``model_id``, or ``None``."""
        with self._lock:
            return self._sessions.get(model_id)

    def snapshot(self) -> dict[str, DownloadSession]:
        """Return a shallow copy of the current sessions map."""
        with self._lock:
            return self._sessions.copy()

    # ----- mutations -----

    def try_register(self, model_id: str, url: str) -> Optional[DownloadSession]:
        """Atomically register a new session iff none exists for ``model_id``.

        Returns the new session on success, or ``None`` if one is already
        in flight. Callers must check the return value — the caller is
        the sole owner of the session it gets back.
        """
        with self._lock:
            if model_id in self._sessions:
                return None
            self._epoch_counter += 1
            fresh = DownloadSession(
                model_id=model_id,
                url=url,
                epoch=self._epoch_counter,
            )
            self._sessions[model_id] = fresh
            return fresh

    def update_progress(
        self,
        session: DownloadSession,
        bytes_downloaded: int,
        total_bytes: Optional[int],
    ) -> None:
        """Record progress on a session.

        No-op if the session is no longer the registered one (cancelled
        or replaced) — callers should check ``is_active`` separately.
        ``progress`` is only recomputed when ``total_bytes`` is positive.
        """
        with self._lock:
            live = self._live_match(session)
            if live is None:
                return
            live.bytes_downloaded = bytes_downloaded
            live.total_bytes = total_bytes
            if total_bytes and total_bytes > 0:
                live.progress = min(1.0, bytes_downloaded / total_bytes)

    def is_active(self, session: DownloadSession) -> bool:
        """True iff this exact session is still registered for its model_id.

        False after cancellation, after completion, or if another session
        has replaced it.
        """
        with self._lock:
            return self._live_match(session) is not None

    def finish(self, session: DownloadSession) -> None:
        """Remove a completed (or cancelled) session from the registry.

        Safe to call multiple times. Only removes when the epoch matches,
        so a *newer* session for the same model_id is never evicted.
        """
        with self._lock:
            if self._live_match(session) is not None:
                del self._sessions[session.model_id]

    def reset_for_tests(self) -> None:
        """Clear all sessions and reset the epoch counter. Test-only."""
        with self._lock:
            self._sessions.clear()
            self._epoch_counter = 0

    def cancel(self, model_id: str) -> bool:
        """Remove the session registered for ``model_id``.

        Returns True if there was an active session to cancel. The worker
        discovers the cancellation on its next ``is_active`` check and
        rolls back its ``.tmp`` file.
        """
        with self._lock:
            return self._sessions.pop(model_id, None) is not None


DOWNLOAD_SERVER = DownloadServer()

View File

@ -0,0 +1,205 @@
"""Streaming download worker with progress reporting and cancellation.
Each download writes to ``<final_path>.tmp`` and atomically renames into
place on success. Between chunks the worker checks the registry for
cancellation (via ``DownloadServer.is_active``) and rolls back its
``.tmp`` on cancel or on any error.
"""
from __future__ import annotations
import asyncio
import logging
import os
from typing import Optional
import aiohttp
from app.model_downloader.download_server import (
DOWNLOAD_SERVER,
DownloadCancelled,
DownloadSession,
)
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
from app.model_downloader.hf_url import is_hf_url
from app.model_downloader.http_client import get_session, parse_content_length
from app.model_downloader.paths import resolve_destination
CHUNK_SIZE = 64 * 1024  # 64 KiB — same scale as other ComfyUI download paths.
# No overall deadline (model files can take arbitrarily long to transfer);
# only the connect phase (30 s) and the gap between reads (120 s) are bounded.
REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=120)
async def stream_to_disk(session: DownloadSession) -> str:
    """Run a single download to completion or cancellation.

    Streams the response body into the ``.tmp`` path and promotes it to
    the final path with ``os.replace`` once the body is complete and the
    session is still active.

    Args:
        session: The registered session to download.

    Returns:
        The final on-disk path on success.

    Raises:
        DownloadCancelled: The session was removed from the registry
            while the download was running.
        DownloadError: The server answered with a status other than 200.
        Exception: Any other failure is logged and re-raised unchanged.

    The ``.tmp`` is removed on cancellation or failure. The session is
    finished (removed from the registry) exactly once, here — callers do
    not need to call ``DOWNLOAD_SERVER.finish`` themselves.
    """
    final_path, tmp_path = resolve_destination(session.model_id)
    os.makedirs(os.path.dirname(final_path), exist_ok=True)

    # Wipe any stale .tmp from a previous failed attempt before we start —
    # otherwise a partial body could masquerade as our completed download
    # when the rename finally happens.
    _remove_if_exists(tmp_path)

    bytes_seen = 0
    try:
        http = await get_session()
        headers = _auth_headers_for(session.url)
        logging.info(
            "[model_downloader] starting GET %s (auth=%s)",
            session.url, "yes" if "Authorization" in headers else "no",
        )
        async with http.get(
            session.url,
            allow_redirects=True,
            timeout=REQUEST_TIMEOUT,
            headers=headers,
        ) as resp:
            if resp.status != 200:
                # Capture a snippet of the response body so 4xx/5xx aren't
                # opaque in the logs — HF returns JSON or HTML with a
                # human-readable reason on failures.
                body_snippet = await _read_short(resp)
                logging.warning(
                    "[model_downloader] GET %s failed: status=%d final_url=%s body=%s",
                    session.url, resp.status, str(resp.url), body_snippet,
                )
                raise DownloadError(
                    f"unexpected HTTP {resp.status} fetching {session.url}: {body_snippet}",
                    status=resp.status,
                )

            total = parse_content_length(resp.headers.get("Content-Length"))
            # Publish the expected size before the first chunk arrives.
            DOWNLOAD_SERVER.update_progress(session, 0, total)

            with open(tmp_path, "wb") as f:
                async for chunk in resp.content.iter_chunked(CHUNK_SIZE):
                    # Cancellation check between chunks. Cheap and means
                    # cancellation latency is bounded by one chunk plus
                    # one ``write()`` — typically well under a second
                    # even on slow disks.
                    if not DOWNLOAD_SERVER.is_active(session):
                        raise DownloadCancelled()
                    f.write(chunk)
                    bytes_seen += len(chunk)
                    DOWNLOAD_SERVER.update_progress(session, bytes_seen, total)

        # Final cancellation check before we promote the .tmp to the real
        # filename — avoids the awkward case where cancel arrives during
        # the very last chunk and we'd otherwise commit anyway.
        if not DOWNLOAD_SERVER.is_active(session):
            raise DownloadCancelled()

        # Atomic rename. os.replace is atomic within the same filesystem,
        # which is guaranteed here because tmp lives alongside final_path.
        os.replace(tmp_path, final_path)
        logging.info(
            "[model_downloader] downloaded %s (%d bytes) from %s",
            session.model_id, bytes_seen, session.url,
        )
        return final_path

    except DownloadCancelled:
        logging.info("[model_downloader] cancelled: %s", session.model_id)
        _remove_if_exists(tmp_path)
        raise
    except Exception as e:
        # NOTE(review): asyncio.CancelledError derives from BaseException,
        # so cancelling the task itself skips this cleanup and leaves the
        # .tmp behind; only the ``finally`` below still runs.
        logging.warning(
            "[model_downloader] failed: %s from %s: %s: %s",
            session.model_id, session.url, type(e).__name__, e,
            exc_info=True,
        )
        _remove_if_exists(tmp_path)
        raise
    finally:
        # In all terminal states (success / cancel / error) drop the
        # session from the registry. Idempotent — only removes if we're
        # still the live epoch for this model_id.
        DOWNLOAD_SERVER.finish(session)
class DownloadError(Exception):
    """Network / protocol failure while downloading.

    Attributes are documented inline: ``status`` carries the HTTP status
    code when the failure came from a response, otherwise ``None``.
    """

    def __init__(self, message: str, status: Optional[int] = None) -> None:
        # HTTP status code of the failing response, if any.
        self.status = status
        super().__init__(message)
async def _read_short(resp: aiohttp.ClientResponse, limit: int = 512) -> str:
    """Return up to ``limit`` bytes of the response body as stripped text.

    Meant for log messages on non-200 responses. Bytes are decoded as
    UTF-8 with undecodable sequences replaced. Best-effort: any exception
    while reading or decoding yields ``"<unreadable>"`` instead.
    """
    try:
        snippet = await resp.content.read(limit)
        text = snippet.decode("utf-8", errors="replace")
    except Exception:
        return "<unreadable>"
    else:
        return text.strip()
def _auth_headers_for(url: str) -> dict[str, str]:
    """Build the auth headers to send with the GET for ``url``.

    HuggingFace URLs get the stored access token as a Bearer header,
    which is what gated repos require. Every other host gets no extra
    headers, so the token is never sent anywhere else. An HF URL with no
    stored token (or an empty one) also gets an empty dict.
    """
    headers: dict[str, str] = {}
    if is_hf_url(url):
        stored = HF_AUTH_STORE.get_token_sync()
        if stored is not None and stored.access_token:
            headers["Authorization"] = f"Bearer {stored.access_token}"
    return headers
def _remove_if_exists(path: str) -> None:
    """Delete ``path``, treating a missing file as success.

    Any other ``OSError`` is logged as a warning and swallowed, so
    cleanup never masks the error that triggered it.
    """
    try:
        os.remove(path)
    except OSError as e:
        if isinstance(e, FileNotFoundError):
            return
        logging.warning("[model_downloader] could not remove %s: %s", path, e)
async def run_batch_sequential(sessions: list[DownloadSession]) -> None:
    """Download the given sessions strictly one after another.

    Sessions are independent: a failure or cancellation of one never
    aborts the rest. A session that was cancelled before its turn is no
    longer active in the registry and is skipped without touching disk.
    """
    for pending in sessions:
        if DOWNLOAD_SERVER.is_active(pending):
            try:
                await stream_to_disk(pending)
            except Exception:
                # Covers DownloadCancelled as well as real failures.
                # stream_to_disk has already logged and cleaned up its
                # .tmp, so just move on to the next session.
                pass
        else:
            # Pre-cancelled before the worker reached it: this is what
            # makes the per-request "sequential but cancellable" semantic
            # work.
            DOWNLOAD_SERVER.finish(pending)
# Strong references to in-flight batch tasks. The event loop keeps only
# weak references to tasks, so a fire-and-forget task nobody holds could
# be garbage-collected before it finishes.
_BACKGROUND_TASKS: set[asyncio.Task] = set()


def schedule_batch(sessions: list[DownloadSession]) -> asyncio.Task:
    """Kick off ``run_batch_sequential`` on the running event loop.

    The task is fire-and-forget from the caller's point of view: the API
    handler returns immediately after scheduling and clients observe
    progress via the polling endpoints. The task is held in
    ``_BACKGROUND_TASKS`` until it completes so it cannot disappear
    mid-download, and is dropped from the set by a done-callback.

    Args:
        sessions: Registered sessions to download, in order.

    Returns:
        The scheduled task. Callers may ignore it.
    """
    task = asyncio.create_task(run_batch_sequential(sessions))
    _BACKGROUND_TASKS.add(task)
    task.add_done_callback(_BACKGROUND_TASKS.discard)
    return task

View File

@ -0,0 +1,238 @@
"""Per-URL probes for the unified availability endpoint.
Three cached/derived facts per URL:
- ``is_gated``           — intrinsic to the model; cached forever once known.
                           Determined by ``auth_check(repo_id, token=None)``:
                           ``GatedRepoError`` → True, success → False.
- ``is_hf_downloadable`` — depends on the *current* token; recomputed every
                           call. For non-gated URLs this is trivially True
                           (no HF call needed). For gated URLs we run
                           ``auth_check`` with the stored token each call.
- ``file_size``          — intrinsic to the file. Cached forever once
                           determined (including ``None`` on transient
                           failure — we don't retry). We only attempt the
                           HEAD when we already know the URL is downloadable
                           to us; that way a failed-because-gated probe
                           never lands as a cached ``None``.
Caches are per-process, in-memory; small, no eviction needed for the
workflow-scale (~tens of URLs). Concurrent calls for the same URL
deduplicate via per-URL ``asyncio.Lock``.
"""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from typing import Optional
import aiohttp
from huggingface_hub import HfApi
from huggingface_hub.errors import (
GatedRepoError,
HfHubHTTPError,
RepositoryNotFoundError,
)
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
from app.model_downloader.hf_url import is_hf_url, repo_id_from_url
from app.model_downloader.http_client import get_session, parse_content_length
# Overall 15-second budget for a size probe request.
_HEAD_TIMEOUT = aiohttp.ClientTimeout(total=15)
@dataclass
class ProbeResult:
    """Outcome of probing one URL."""

    # Size in bytes, or None when unknown / not probed.
    file_size: Optional[int]
    # True / False for HF URLs according to current token access; None for
    # non-HF URLs or when the check could not be completed.
    is_hf_downloadable: Optional[bool]
# --- caches -------------------------------------------------------------- #
# url → bool. Whether this URL's HF repo gates access. Treated as intrinsic
# to the model, so an entry is never invalidated once written.
_is_gated_cache: dict[str, bool] = {}

# url → Optional[int]. The file's size in bytes, ``None`` if a probe
# was attempted and produced no answer. **For HF URLs, only populated when
# we knew the URL was downloadable to us at probe time** — so gated-without-
# access never lands a ``None`` here that we'd be stuck with after login.
_file_size_cache: dict[str, Optional[int]] = {}

# Per-URL locks for single-flight probes — when multiple polls arrive
# in the same tick for the same URL, exactly one of them runs the HF
# call and the others wait on the result.
_locks: dict[str, asyncio.Lock] = {}
def _lock_for(url: str) -> asyncio.Lock:
    """Return the lock dedicated to ``url``, creating it on first use."""
    try:
        return _locks[url]
    except KeyError:
        created = _locks[url] = asyncio.Lock()
        return created
def clear_caches_for_tests() -> None:
    """Test-only: empty the gated cache, the size cache, and the lock map."""
    for store in (_is_gated_cache, _file_size_cache, _locks):
        store.clear()
# --- public entrypoint --------------------------------------------------- #
async def probe_url(url: str) -> ProbeResult:
    """Return downloadability + size for one URL, using caches where safe.

    - Non-HF URL: size only; ``is_hf_downloadable`` is ``None`` ("not
      applicable").
    - HF URL whose repo id can't be extracted, or whose gating can't be
      determined right now: both fields ``None``.
    - Otherwise ``is_hf_downloadable`` is True for non-gated repos and the
      result of an auth check with the stored token for gated ones; the
      size is probed only when that value is exactly True.
    """
    if not is_hf_url(url):
        # Size goes through the shared size helper so polls don't each
        # issue their own request.
        plain_size = await _get_or_probe_size(url, token=None)
        return ProbeResult(file_size=plain_size, is_hf_downloadable=None)

    repo_id = repo_id_from_url(url)
    # Gating is intrinsic to the repo and resolved once.
    gated = None if repo_id is None else await _resolve_is_gated(url, repo_id)
    if gated is None:
        return ProbeResult(file_size=None, is_hf_downloadable=None)

    # Downloadability depends on the token held *now*, so it is computed
    # on every call.
    stored = HF_AUTH_STORE.get_token_sync()
    token_str: Optional[str] = stored.access_token if stored else None
    can_fetch: Optional[bool] = (
        await _auth_check_with_token(repo_id, token_str) if gated else True
    )

    # Skip the size probe unless access is confirmed — it would 401 and
    # leave a cached None that survives a later login.
    size = (
        await _get_or_probe_size(url, token=token_str)
        if can_fetch is True
        else None
    )
    return ProbeResult(file_size=size, is_hf_downloadable=can_fetch)
# --- gated/auth probes --------------------------------------------------- #
async def _resolve_is_gated(url: str, repo_id: str) -> Optional[bool]:
    """Resolve whether ``repo_id`` is gated, caching the verdict under ``url``.

    Returns True/False once known, or None when the anonymous probe failed
    for a reason that should be retried (nothing is cached in that case).
    """
    known = _is_gated_cache.get(url)
    if known is not None:
        return known

    async with _lock_for(url):
        # Re-check: another waiter may have filled the cache meanwhile.
        known = _is_gated_cache.get(url)
        if known is not None:
            return known
        try:
            # Anonymous check (token=None) reveals the repo's intrinsic state.
            await asyncio.to_thread(_auth_check_sync, repo_id, None)
        except (GatedRepoError, RepositoryNotFoundError):
            # A repo that isn't publicly visible is treated like a gated one:
            # we can't serve it without auth, and an authenticated check might
            # still succeed if it's a private repo the user can see.
            verdict = True
        except Exception as e:
            logging.debug(
                "[hf_auth] is_gated probe failed for %s (will retry): %s",
                repo_id, e,
            )
            return None  # don't cache; retry next call
        else:
            verdict = False
        _is_gated_cache[url] = verdict
        return verdict
async def _auth_check_with_token(
    repo_id: str, token: Optional[str]
) -> Optional[bool]:
    """Run the auth check with ``token`` and map the outcome to a tri-state.

    True  — the check passed.
    False — gated / not found, or HTTP 401/403.
    None  — any other failure (logged); the caller cannot tell either way.
    """
    try:
        await asyncio.to_thread(_auth_check_sync, repo_id, token)
    except (GatedRepoError, RepositoryNotFoundError):
        return False
    except HfHubHTTPError as e:
        # 401/403 covers org-SSO-required, revoked tokens, and similar —
        # all of which mean "can't fetch right now" from the user's POV.
        response = getattr(e, "response", None)
        if getattr(response, "status_code", None) in (401, 403):
            return False
        logging.debug(
            "[hf_auth] auth_check transient failure for %s: %s", repo_id, e,
        )
        return None
    except Exception as e:
        logging.warning("[hf_auth] unexpected auth_check error for %s: %s", repo_id, e)
        return None
    return True
def _auth_check_sync(repo_id: str, token: Optional[str]) -> None:
    """Blocking ``HfApi.auth_check`` call, meant to run via ``asyncio.to_thread``.

    Returns nothing; any exception from ``auth_check`` propagates to the caller.
    """
    api = HfApi()
    api.auth_check(repo_id, token=token)
# --- size probe ---------------------------------------------------------- #
async def _get_or_probe_size(url: str, token: Optional[str]) -> Optional[int]:
    """Return the size cached for ``url``; otherwise probe once and cache it.

    The probe result is cached even when it is ``None``, so a URL is
    HEAD-ed at most once until the cache is cleared.
    """
    absent = object()  # distinguishes "not cached" from a cached None

    hit = _file_size_cache.get(url, absent)
    if hit is not absent:
        return hit

    async with _lock_for(url):
        # Another waiter may have completed the probe while we queued.
        hit = _file_size_cache.get(url, absent)
        if hit is absent:
            hit = await _probe_size_once(url, token=token)
            _file_size_cache[url] = hit
        return hit
async def _probe_size_once(url: str, token: Optional[str]) -> Optional[int]:
    """HEAD the URL and return the file size in bytes, or None on failure.

    HuggingFace serves LFS-tracked files via a 302 to a signed CDN URL.
    The real file size lives in the non-standard ``X-Linked-Size`` header
    on that 302 response (``Content-Length`` is the redirect-body length).
    Disabling redirect-follow lets us read either header on the same
    response:

    - LFS files: 302 + ``X-Linked-Size``
    - Small/non-LFS files: 200 + ``Content-Length``

    Args:
        url: The URL to probe.
        token: Optional bearer token sent as an ``Authorization`` header.

    Returns:
        The size in bytes, or None when no usable header is present, the
        status is not 200, or the request fails or times out.
    """
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    try:
        session = await get_session()
        async with session.head(
            url, allow_redirects=False, timeout=_HEAD_TIMEOUT, headers=headers,
        ) as resp:
            linked = parse_content_length(resp.headers.get("X-Linked-Size"))
            if linked is not None:
                return linked
            if resp.status == 200:
                return parse_content_length(resp.headers.get("Content-Length"))
            return None
    except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError, OSError):
        # aiohttp signals request timeouts with ``asyncio.TimeoutError``,
        # which is only an alias of the builtin ``TimeoutError`` from
        # Python 3.11 on; naming both keeps older interpreters covered.
        return None
# Backward-compat shim so consumers that still import the old name keep
# building during the refactor; can be removed once routes are updated.
MetadataProbeResult = ProbeResult

View File

@ -0,0 +1,106 @@
"""In-memory token cache with lazy disk loading + OAuth refresh.
Public surface is the ``HF_AUTH_STORE`` singleton. Callers ask
``get_valid_token()``; the store transparently loads the token from disk
on first use, refreshes via the OAuth refresh_token if the cached
access_token is expired, and returns ``None`` if neither path works.
The refresh path imports ``oauth.refresh_access_token`` lazily to
avoid an import cycle (oauth needs the store to save tokens it
acquires).
"""
from __future__ import annotations
import asyncio
import logging
import threading
from typing import Optional
from app.model_downloader.hf_auth.token_store import (
Token,
delete_token,
load_token,
save_token,
)
class HfAuthStore:
    """In-memory holder for the HuggingFace token, backed by the on-disk copy.

    The disk token is read lazily on first access. ``set_token`` and
    ``clear`` update memory and disk together under a thread lock.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._token: Optional[Token] = None
        self._loaded_from_disk = False

    def _ensure_loaded(self) -> None:
        """Read the disk token into memory on first access."""
        if self._loaded_from_disk:
            return
        with self._lock:
            # Re-check under the lock so only one caller hits the disk.
            if self._loaded_from_disk:
                return
            self._token = load_token()
            self._loaded_from_disk = True

    def has_token(self) -> bool:
        """Cheap check: is there any token in memory?

        Does not attempt refresh; an expired-but-refreshable token still
        counts as "logged in" from the user's perspective.
        """
        self._ensure_loaded()
        return self._token is not None

    def set_token(self, token: Token) -> None:
        """Replace the in-memory token and persist to disk."""
        with self._lock:
            self._token = token
            self._loaded_from_disk = True
            save_token(token)

    def clear(self) -> None:
        """Forget the token in memory and on disk (logout)."""
        with self._lock:
            self._token = None
            self._loaded_from_disk = True
            delete_token()

    def get_token_sync(self) -> Optional[Token]:
        """Return the cached token without refreshing.

        Sync callers (e.g. constructing an Authorization header in a
        non-async path) use this. They accept an expired token over
        ``None``; HF will simply return 401 and the caller can decide
        what to do.
        """
        self._ensure_loaded()
        return self._token

    async def get_valid_token(self) -> Optional[Token]:
        """Return a fresh token, refreshing via OAuth if necessary.

        Returns ``None`` if there's no cached token at all, or if the
        cached token is expired and refresh failed. Callers should
        treat that as "user is not logged in".

        If the stored token changes while the refresh request is in flight
        (logout, or a new login), the refreshed token is discarded rather
        than written back, so a logout is never undone by a late refresh.
        """
        self._ensure_loaded()
        tok = self._token
        if tok is None:
            return None
        if tok.is_valid():
            return tok
        if not tok.refresh_token:
            return None
        # Lazy import to avoid the oauth ↔ store import cycle.
        from app.model_downloader.hf_auth.oauth import refresh_access_token
        try:
            refreshed = await refresh_access_token(tok.refresh_token)
        except Exception as e:
            logging.warning("[hf_auth] token refresh failed: %s", e)
            return None

        with self._lock:
            if self._token is tok:
                # Nothing changed while we awaited: commit the refresh.
                self._token = refreshed
                self._loaded_from_disk = True
                save_token(refreshed)
                return refreshed
            # The store was cleared or replaced during the await; honour
            # that newer state instead of overwriting it.
            current = self._token
        if current is not None and current.is_valid():
            return current
        return None
HF_AUTH_STORE = HfAuthStore()

View File

@ -0,0 +1,55 @@
"""Whether this deployment is allowed to do interactive HF OAuth.
We only let the server hold a HuggingFace token under a strict trust
assumption: this is a *single tenant local* install. Concretely:
- The server is bound to a loopback address. SSH tunneling /
reverse-proxies can defeat this, but it's the strongest signal
we have without an authentication system.
- ``--multi-user`` is off. A shared token used implicitly by multiple
declared users would be a footgun one user's gated downloads
would silently authenticate as another.
Anything else and the frontend hides the HF login UI entirely; gated
models continue to show the "acquire it manually" message.
"""
from __future__ import annotations
import ipaddress
import socket
from comfy.cli_args import args
def _is_loopback(host: str | None) -> bool:
    """Duplicates ``server.is_loopback`` (small, no shared module yet).

    Resolves a host or IP literal to whether it lives on the loopback
    interface (127.0.0.0/8 for IPv4, ::1 for IPv6). Returns False for
    ``0.0.0.0`` / ``::`` because those are bind-all wildcards, not
    loopback.

    A hostname counts as loopback only if it resolves to at least one
    address and *every* resolved address is loopback. A missing or empty
    host, and a name that cannot be resolved, are not loopback.
    """
    if not host:
        return False
    try:
        return ipaddress.ip_address(host).is_loopback
    except ValueError:
        pass  # not an IP literal — fall through to name resolution

    saw_loopback = False
    for family in (socket.AF_INET, socket.AF_INET6):
        try:
            infos = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
        except (OSError, UnicodeError):
            # ``socket.gaierror`` is an OSError; UnicodeError covers names
            # that cannot be IDNA-encoded. Either way: no addresses here.
            continue
        for *_unused, sockaddr in infos:
            try:
                address = ipaddress.ip_address(sockaddr[0])
            except ValueError:
                return False  # unparsable address: don't assume loopback
            if not address.is_loopback:
                # One non-loopback address disqualifies the host, even if
                # an earlier address was loopback.
                return False
            saw_loopback = True
    return saw_loopback
def is_hf_auth_eligible() -> bool:
    """Report whether this deployment may surface the HF OAuth flow.

    Eligible only when the configured listen address is loopback and
    ``--multi-user`` is off.
    """
    if not _is_loopback(args.listen):
        return False
    return not args.multi_user

View File

@ -0,0 +1,277 @@
"""OAuth 2.0 PKCE flow against HuggingFace's authorization server.
Wired so that ``POST /api/hf-auth-login-start`` can:
1. Generate state + PKCE verifier/challenge in this process.
2. Spin up a short-lived loopback HTTP server at port 41954 to
receive the redirect callback from HF.
3. Return the ``authorize_url`` for the frontend to open in a new tab.
After the user grants consent on huggingface.co, HF redirects to the
local callback URL with ``code`` and ``state``. The callback server
validates ``state`` (CSRF), exchanges the code for tokens via PKCE,
hands the resulting Token to ``HF_AUTH_STORE.set_token``, and shuts
itself down.
Before this can be exercised end-to-end a maintainer must register a
HuggingFace OAuth app and substitute the ``HF_CLIENT_ID`` placeholder
below. See the comment above the constant for the exact steps.
"""
from __future__ import annotations
import asyncio
import base64
import hashlib
import logging
import secrets
import threading
import time
import aiohttp
from aiohttp import web
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
from app.model_downloader.hf_auth.token_store import Token
from app.model_downloader.http_client import ssl_context
# --- HF OAuth app registration -------------------------------------------- #
# NOTE: The OAuth client_id below is a placeholder. Before this feature can be
# exercised end-to-end, a maintainer must register a HuggingFace OAuth app
# under a Comfy-Org-controlled HF account and substitute its client_id here.
# Detailed walkthrough is in docs/server-side-model-downloads-handover.html
# ("HuggingFace OAuth app setup" section). Short version:
# 1. huggingface.co → Settings → Connected Apps → "Create app"
# 2. Default Scopes: check ``openid`` + ``profile`` (User Info) and
# ``gated-repos`` (Repository Access). Leave everything else off.
# 3. Redirect URLs: exactly ``http://127.0.0.1:41954/api/auth/huggingface/callback``
# — must match ``REDIRECT_URI`` below; change both in lockstep if you
# change ``CALLBACK_PORT``.
# 4. Save → copy the resulting Client ID into ``HF_CLIENT_ID`` below.
# The client_id is not a secret (it travels through the user's browser in
# plaintext); HF's "Public app" type means there's no client secret to
# manage — PKCE replaces it.
HF_CLIENT_ID = "REPLACE_ME_WITH_COMFY_ORG_HF_OAUTH_CLIENT_ID"
CALLBACK_HOST = "127.0.0.1"
CALLBACK_PORT = 41954
CALLBACK_PATH = "/api/auth/huggingface/callback"
REDIRECT_URI = f"http://{CALLBACK_HOST}:{CALLBACK_PORT}{CALLBACK_PATH}"
AUTHORIZE_URL = "https://huggingface.co/oauth/authorize"
TOKEN_URL = "https://huggingface.co/oauth/token"
# Minimal scope set for the feature:
# - openid : required by HF when the app uses OIDC at all
# - profile : lets ``HfApi.whoami(token=...)`` return a username for the
# settings UI; cosmetic but expected
# - gated-repos : grants the token enough to call ``auth_check`` and
# download files from public gated repos the user has
# accepted the license for. The wider ``read-repos`` scope
# would also work (it includes ``gated-repos``) but it
# additionally grants private-repo read access, which we
# don't need and which makes the consent screen scarier
# for the user.
SCOPE = "openid profile gated-repos"
# Maximum time the callback server stays up waiting for the user to
# complete consent on huggingface.co. Past this, the port closes and
# the user has to click "Log in" again.
CALLBACK_TIMEOUT_SECS = 300
# Process-wide lock so two simultaneous /api/hf-auth-login-start
# requests don't fight over port CALLBACK_PORT.
_OAUTH_LOCK = threading.Lock()
class OAuthInProgressError(Exception):
    """Raised when a login is requested while another OAuth attempt is still running."""
class OAuthCallbackError(Exception):
    """Raised when the OAuth callback reports a failure (HF denied, port stolen, etc.)."""
# --- PKCE primitives ------------------------------------------------------ #
def _make_pkce() -> tuple[str, str, str]:
    """Generate fresh PKCE material and return ``(verifier, challenge, state)``.

    The verifier stays inside this process. The challenge (unpadded
    URL-safe base64 of the verifier's SHA-256) and the state travel through
    the user's browser; the state is compared on the callback so a
    malicious cross-origin redirect cannot inject a token.
    """
    verifier = secrets.token_urlsafe(64)
    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    # Strip the base64 "=" padding from the encoded digest.
    challenge = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")
    state = secrets.token_urlsafe(32)
    return verifier, challenge, state
def _build_authorize_url(challenge: str, state: str) -> str:
    """Compose the HuggingFace authorize URL for one PKCE login attempt."""
    from urllib.parse import urlencode

    query = urlencode(
        [
            ("client_id", HF_CLIENT_ID),
            ("redirect_uri", REDIRECT_URI),
            ("response_type", "code"),
            ("scope", SCOPE),
            ("state", state),
            ("code_challenge", challenge),
            ("code_challenge_method", "S256"),
        ]
    )
    return AUTHORIZE_URL + "?" + query
# --- Token exchange ------------------------------------------------------- #
async def _exchange_code(code: str, verifier: str) -> Token:
    """Exchange an authorization ``code`` (plus PKCE verifier) for a Token.

    POSTs the form to ``TOKEN_URL`` with a 30 second total timeout and
    raises on a non-success HTTP status.
    """
    form = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": REDIRECT_URI,
        "client_id": HF_CLIENT_ID,
        "code_verifier": verifier,
    }
    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=30)
    ) as http:
        async with http.post(TOKEN_URL, data=form, ssl=ssl_context()) as reply:
            reply.raise_for_status()
            payload = await reply.json()
    # ``expires_in`` is relative seconds; store an absolute epoch instead.
    return Token(
        access_token=payload["access_token"],
        refresh_token=payload.get("refresh_token"),
        expires_at=time.time() + float(payload.get("expires_in", 3600)),
        scope=payload.get("scope", SCOPE),
    )
async def refresh_access_token(refresh_token: str) -> Token:
    """Trade a refresh_token for a new access (+ possibly refresh) token.

    Fields that the response omits *or* sends as JSON ``null`` fall back to
    defaults: the existing ``refresh_token``, a 3600 second lifetime, and
    the module's ``SCOPE``.

    Raises:
        aiohttp.ClientResponseError: on a non-success HTTP status.
        KeyError: if the response carries no ``access_token``.
    """
    data = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": HF_CLIENT_ID,
    }
    timeout = aiohttp.ClientTimeout(total=30)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(TOKEN_URL, data=data, ssl=ssl_context()) as resp:
            resp.raise_for_status()
            body = await resp.json()

    access_token = body["access_token"]
    # If HF doesn't rotate refresh tokens (field missing, null or empty),
    # keep using the existing one instead of losing the ability to refresh.
    new_refresh = body.get("refresh_token") or refresh_token
    expires_in = body.get("expires_in")
    if expires_in is None:
        expires_in = 3600
    scope = body.get("scope")
    if scope is None:
        scope = SCOPE
    return Token(
        access_token=access_token,
        refresh_token=new_refresh,
        expires_at=time.time() + float(expires_in),
        scope=scope,
    )
# --- Callback server ------------------------------------------------------ #
# Strong references to in-flight callback-server tasks. The event loop only
# keeps weak references to tasks, so an otherwise unreferenced task can be
# garbage-collected before it finishes.
_CALLBACK_TASKS: set[asyncio.Task] = set()


async def start_login_flow() -> str:
    """Begin one OAuth attempt: spawn the callback server, return the URL.

    Returns the URL the frontend should open in a new tab. Raises
    ``OAuthInProgressError`` if another attempt is already running.

    The callback server runs in the background until the user
    completes consent or until ``CALLBACK_TIMEOUT_SECS`` elapses;
    either way the lock + port are released afterward.
    """
    if not _OAUTH_LOCK.acquire(blocking=False):
        raise OAuthInProgressError()
    try:
        verifier, challenge, state = _make_pkce()
        authorize_url = _build_authorize_url(challenge, state)
        # Fire the callback server on the running loop and return.
        task = asyncio.create_task(_run_callback_server(verifier, state))
    except BaseException:
        # The background task never started, so nothing else will release
        # the lock — do it here or every later login would be rejected.
        _OAUTH_LOCK.release()
        raise
    _CALLBACK_TASKS.add(task)
    task.add_done_callback(_CALLBACK_TASKS.discard)
    return authorize_url
async def _run_callback_server(verifier: str, expected_state: str) -> None:
    """Listen for HF's redirect once, capture the token, then shut down.

    Args:
        verifier: PKCE verifier used for the code exchange.
        expected_state: State value the callback must echo back.

    The runner is cleaned up and ``_OAUTH_LOCK`` is released on every exit
    path, including failures while setting up or binding the server.
    """
    received: asyncio.Future[Token] = asyncio.get_running_loop().create_future()

    async def handler(request: web.Request) -> web.Response:
        try:
            # A wrong state is answered with 400 but does not end the
            # attempt; the server keeps waiting for the genuine callback.
            if request.query.get("state") != expected_state:
                return web.Response(status=400, text="state mismatch")
            err = request.query.get("error")
            if err:
                # Guard: a second callback must not raise on a settled future.
                if not received.done():
                    received.set_exception(OAuthCallbackError(f"HF returned: {err}"))
                return web.Response(status=400, text=f"OAuth error: {err}")
            code = request.query.get("code")
            if not code:
                return web.Response(status=400, text="missing code")
            tok = await _exchange_code(code, verifier)
            if not received.done():
                received.set_result(tok)
            return web.Response(
                content_type="text/html",
                text=(
                    "<html><body style='font-family:sans-serif;padding:40px'>"
                    "<h2>HuggingFace login successful</h2>"
                    "<p>You can close this tab and return to ComfyUI.</p>"
                    "</body></html>"
                ),
            )
        except Exception as exc:
            if not received.done():
                received.set_exception(exc)
            return web.Response(status=500, text=str(exc))

    app = web.Application()
    app.router.add_get(CALLBACK_PATH, handler)
    runner = web.AppRunner(app)
    try:
        await runner.setup()
        site = web.TCPSite(runner, CALLBACK_HOST, CALLBACK_PORT, reuse_address=True)
        try:
            await site.start()
        except OSError as e:
            # Port already in use (or some other socket-bind failure). The
            # outer ``finally`` releases the lock so a future attempt has a
            # chance to succeed.
            logging.warning("[hf_auth] could not bind callback port: %s", e)
            return
        try:
            token = await asyncio.wait_for(received, timeout=CALLBACK_TIMEOUT_SECS)
        except asyncio.TimeoutError:
            logging.info("[hf_auth] OAuth login timed out after %ds", CALLBACK_TIMEOUT_SECS)
        except OAuthCallbackError as e:
            logging.warning("[hf_auth] OAuth callback error: %s", e)
        except Exception as e:
            logging.warning("[hf_auth] unexpected OAuth failure: %s", e)
        else:
            HF_AUTH_STORE.set_token(token)
            logging.info("[hf_auth] OAuth login complete")
    finally:
        try:
            await runner.cleanup()
        finally:
            # Release even if cleanup itself fails.
            if _OAUTH_LOCK.locked():
                _OAUTH_LOCK.release()
def is_login_in_progress() -> bool:
    """Report whether an OAuth attempt currently holds ``_OAUTH_LOCK``."""
    attempt_running = _OAUTH_LOCK.locked()
    return attempt_running
# Public API of this module. NOTE(review): the URL builder and
# ``OAuthCallbackError`` are not listed here — confirm that is intentional.
__all__ = [
"start_login_flow",
"refresh_access_token",
"is_login_in_progress",
"OAuthInProgressError",
"CALLBACK_TIMEOUT_SECS",
]

View File

@ -0,0 +1,89 @@
"""On-disk persistence for the HuggingFace OAuth token.
The token shape mirrors what HF returns on the token exchange: an
``access_token``, an optional ``refresh_token``, the absolute epoch at
which the access token expires, and the granted scope. We persist
this so logging in once survives ComfyUI restarts; the file is mode
``0600`` so only the OS user can read it.
"""
from __future__ import annotations
import json
import logging
import os
import stat
import time
from dataclasses import asdict, dataclass
from typing import Optional
import folder_paths
# Treat a token as expired this many seconds before its server-reported
# ``expires_at`` so we don't try to use a token mid-request only for it
# to flip stale between auth_check and the actual GET.
EXPIRY_BUFFER_SECS = 60
TOKEN_FILENAME = "hf_auth_token.json"
@dataclass
class Token:
    """An OAuth token together with the metadata needed to use it."""

    access_token: str
    refresh_token: Optional[str]
    expires_at: float  # absolute epoch seconds
    scope: str = ""

    def is_valid(self) -> bool:
        """Return True iff the token is non-empty and not within the expiry buffer."""
        if not self.access_token:
            return False
        seconds_left = self.expires_at - time.time()
        return seconds_left > EXPIRY_BUFFER_SECS
def _token_path() -> str:
    """Return the token file's path inside the ComfyUI user directory."""
    return os.path.join(folder_paths.get_user_directory(), TOKEN_FILENAME)
def load_token() -> Optional[Token]:
    """Read the persisted token, returning ``None`` if absent or corrupt.

    A missing file is the normal "not logged in" case and is silent. Any
    other read, decode, parse or shape problem is logged as a warning.
    """
    path = _token_path()
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, dict):
            raise TypeError(f"expected a JSON object, got {type(data).__name__}")
        token = Token(**data)
        # Reject payloads whose core fields have the wrong type; they would
        # otherwise only fail later, inside ``Token.is_valid``.
        if not isinstance(token.access_token, str) or not isinstance(
            token.expires_at, (int, float)
        ):
            raise TypeError("token fields have unexpected types")
        return token
    except FileNotFoundError:
        return None
    except (OSError, ValueError, TypeError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError
        # (a file that is not valid UTF-8).
        logging.warning("[hf_auth] could not load token at %s: %s", path, e)
        return None
def save_token(token: Token) -> None:
    """Atomically write the token with 0600 permissions.

    The temp file is created owner-only *before* any secret is written, so
    the token is never readable by other users, then renamed into place.
    If writing fails the temp file is removed and the error propagates.
    """
    path = _token_path()
    os.makedirs(os.path.dirname(path), exist_ok=True)
    tmp = path + ".tmp"
    owner_only = stat.S_IRUSR | stat.S_IWUSR
    try:
        fd = os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, owner_only)
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            try:
                # The mode passed to os.open only applies to a newly created
                # file; tighten a stale leftover temp file as well.
                os.chmod(tmp, owner_only)
            except OSError as e:
                # On Windows / weird filesystems chmod may be a no-op; not fatal.
                logging.debug("[hf_auth] chmod 0600 on %s failed: %s", tmp, e)
            json.dump(asdict(token), f)
        os.replace(tmp, path)
    except BaseException:
        # Don't leave a partial temp file holding token material behind.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
def delete_token() -> None:
    """Delete the persisted token file; a file that is already gone is fine.

    Any other removal failure is logged as a warning and not raised.
    """
    target = _token_path()
    try:
        os.remove(target)
    except OSError as err:
        if not isinstance(err, FileNotFoundError):
            logging.warning("[hf_auth] could not remove token at %s: %s", target, err)

View File

@ -0,0 +1,41 @@
"""Parsers for the ``huggingface.co`` URL shape we accept in workflows.
The download API accepts URLs of the form
``https://huggingface.co/<org>/<repo>/resolve/<rev>/<path/to/file>``.
We need to recover ``<org>/<repo>`` (the *repo_id*) from such URLs for
``huggingface_hub`` API calls (notably ``HfApi.auth_check``).
"""
from __future__ import annotations
from typing import Optional
from urllib.parse import urlparse
_HF_HOST = "huggingface.co"


def is_hf_url(url: str) -> bool:
    """Return True when the URL's hostname is exactly ``huggingface.co``.

    URLs that ``urlparse`` rejects with ``ValueError`` are not HF URLs.
    """
    try:
        hostname = urlparse(url).hostname
    except ValueError:
        return False
    return hostname == _HF_HOST
def repo_id_from_url(url: str) -> Optional[str]:
    """Pull ``<org>/<repo>`` out of an HF model-file URL.

    Only the ``/<org>/<repo>/resolve/<rev>/<path>`` shape on huggingface.co
    is recognised; every other host or path layout (datasets, spaces,
    /tree/, /blob/, …) yields ``None``.
    """
    if not is_hf_url(url):
        return None
    segments = urlparse(url).path.lstrip("/").split("/")
    if len(segments) >= 4 and segments[2] == "resolve":
        org, repo = segments[:2]
        if org and repo:
            return f"{org}/{repo}"
    return None

View File

@ -0,0 +1,63 @@
"""Lazy module-level aiohttp ClientSession.
A single shared session means TLS handshakes are reused across HEAD probes
and the subsequent GETs to the same host (HuggingFace is the dominant
case), which is a noticeable speedup on cold connections.
We deliberately don't close the session at process exit — aiohttp's
warning about unclosed sessions is benign at shutdown, and adding atexit
plumbing buys nothing because the OS reclaims the sockets anyway. The
session lifetime is the lifetime of the Python process.
"""
from __future__ import annotations
import asyncio
import ssl
from typing import Optional
import aiohttp
import certifi
# Per-host connection cap. NOTE: aiohttp's defaults are ``limit=100`` total
# and ``limit_per_host=0``, where 0 means *no* per-host limit — so this value
# caps concurrent connections to one host at 8 rather than enlarging the pool.
_CONNECTOR_LIMIT_PER_HOST = 8
_session: Optional[aiohttp.ClientSession] = None
_lock = asyncio.Lock()
def ssl_context() -> ssl.SSLContext:
    """Build a default TLS context that trusts certifi's CA bundle.

    aiohttp's default context uses the OS trust store, which isn't wired up
    on some Python installs (python.org macOS, slim containers) — there TLS
    to huggingface.co fails with CERTIFICATE_VERIFY_FAILED.
    """
    ca_bundle = certifi.where()
    return ssl.create_default_context(cafile=ca_bundle)
async def get_session() -> aiohttp.ClientSession:
    """Return the process-wide session, (re)creating it when missing or closed."""
    global _session
    if _session is None or _session.closed:
        async with _lock:
            # Re-check under the lock: a concurrent caller may have just
            # built the session.
            if _session is None or _session.closed:
                _session = aiohttp.ClientSession(
                    connector=aiohttp.TCPConnector(
                        limit_per_host=_CONNECTOR_LIMIT_PER_HOST,
                        ssl=ssl_context(),
                    )
                )
    return _session
def parse_content_length(value: Optional[str]) -> Optional[int]:
    """Parse a byte-count header value, or None if absent/malformed/negative.

    Only plain ASCII decimal digits are accepted (surrounding whitespace is
    ignored). Forms that ``int()`` would tolerate but that are not valid
    byte counts — a sign, ``_`` separators, non-ASCII digits — are rejected.
    """
    if not value:
        return None
    digits = value.strip()
    if not (digits.isascii() and digits.isdigit()):
        return None
    try:
        return int(digits)
    except ValueError:
        # e.g. a digit string longer than the interpreter's int parse limit.
        return None

View File

@ -0,0 +1,93 @@
"""Path resolution for model downloads.
Model identifiers used across the download API are *relative destination
paths* of the form ``<directory>/<filename>`` (e.g. ``loras/my_lora.safetensors``).
This module turns one of those identifiers into an absolute on-disk path
under one of ComfyUI's registered model folders, while rejecting unknown
folders, path traversal, and other ill-formed inputs.
"""
import os
import re
from typing import Optional, Tuple
import folder_paths
# Constrain components so a model_id can never escape its target directory.
# - directory: a single path segment of safe chars
# - filename: a single path segment of safe chars
# NOTE(review): this comment previously also said the filename "must end
# with a model extension", but nothing in this module enforces that —
# confirm whether the check lives elsewhere or should be added here.
_SEGMENT_RE = re.compile(r"^[A-Za-z0-9._-]+$")


class InvalidModelId(ValueError):
    """Raised when a model_id is syntactically invalid or refers to an
    unknown model folder."""


def _is_safe_segment(segment: str) -> bool:
    """True iff ``segment`` is made of safe chars only and is not ``.`` / ``..``."""
    # fullmatch, not match: with ``$`` a plain match would also accept a
    # trailing newline. The dot segments are legal characters but would
    # point at the folder itself or its parent.
    return segment not in (".", "..") and _SEGMENT_RE.fullmatch(segment) is not None


def parse_model_id(model_id: str) -> Tuple[str, str]:
    """Split ``<directory>/<filename>`` and validate both components.

    Returns ``(directory, filename)``. Raises ``InvalidModelId`` on
    malformed input or when ``directory`` is not a registered model
    folder. Does NOT touch the filesystem.
    """
    if not isinstance(model_id, str) or "/" not in model_id:
        raise InvalidModelId(f"model_id must be '<directory>/<filename>', got {model_id!r}")
    directory, _, filename = model_id.partition("/")
    if "/" in filename or not directory or not filename:
        raise InvalidModelId(f"model_id must be exactly one '/' separator, got {model_id!r}")
    if not _is_safe_segment(directory):
        raise InvalidModelId(f"invalid directory segment {directory!r}")
    if not _is_safe_segment(filename):
        raise InvalidModelId(f"invalid filename segment {filename!r}")
    if directory not in folder_paths.folder_names_and_paths:
        raise InvalidModelId(f"unknown model folder {directory!r}")
    return directory, filename
def resolve_existing(model_id: str) -> Optional[str]:
    """Look up an installed model and return its absolute path, else None.

    The lookup is delegated to ``folder_paths.get_full_path``, so
    ``extra_model_paths.yaml`` is honoured transparently. Raises
    ``InvalidModelId`` when ``model_id`` does not validate.
    """
    folder_name, file_name = parse_model_id(model_id)
    return folder_paths.get_full_path(folder_name, file_name)
def resolve_destination(model_id: str) -> Tuple[str, str]:
    """Compute ``(final_path, tmp_path)`` for downloading ``model_id``.

    The file goes to the first ("primary") registered path of the model's
    directory. ``tmp_path`` is the final path with ``.tmp`` appended; it is
    the write target that gets renamed onto the final path on success.
    Raises ``InvalidModelId`` when the id is invalid or the folder has no
    registered on-disk path.
    """
    directory, filename = parse_model_id(model_id)
    candidates = folder_paths.get_folder_paths(directory)
    if not candidates:
        raise InvalidModelId(f"no on-disk path registered for folder {directory!r}")
    destination = os.path.join(candidates[0], filename)
    return destination, destination + ".tmp"
def iter_all_tmp_paths():
    """Yield every ``*.tmp`` file under every registered model folder.

    Used at startup to sweep orphans left by crashed/restarted downloads.
    Only the top level of each registered root is scanned, and each root is
    visited once even when several folder names share it.

    NOTE(review): the match is by suffix only, so any ``.tmp`` file in a
    registered folder is yielded, not just ones this subsystem wrote —
    confirm the caller is happy to treat all of them as orphans.
    """
    seen_roots: set[str] = set()
    for directory in folder_paths.folder_names_and_paths:
        for root in folder_paths.get_folder_paths(directory):
            if root in seen_roots or not os.path.isdir(root):
                continue
            seen_roots.add(root)
            try:
                # ``with`` closes the directory handle even if the consumer
                # stops iterating this generator early.
                with os.scandir(root) as entries:
                    for entry in entries:
                        if entry.name.endswith(".tmp") and entry.is_file():
                            yield entry.path
            except OSError:
                # Folder might be unreadable / missing on certain mounts —
                # not fatal, just skip it.
                continue

View File

@ -98,12 +98,24 @@ def _parse_cli_feature_flags() -> dict[str, Any]:
# Default server capabilities
def _hf_auth_eligible_at_startup() -> bool:
    """Evaluate HF-auth eligibility once, while the feature flags are built.

    The import is deferred into the function because the flags module loads
    very early in the server boot sequence — earlier than the
    model_downloader package.
    """
    from app.model_downloader.hf_auth.eligibility import is_hf_auth_eligible

    eligible = is_hf_auth_eligible()
    return eligible
_CORE_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
"node_replacements": True,
"assets": args.enable_assets,
"server_side_model_downloads": True,
"hf_auth_eligible": _hf_auth_eligible_at_startup(),
}
# CLI-provided flags cannot overwrite core flags

File diff suppressed because it is too large Load Diff

View File

@ -188,6 +188,49 @@ components:
- id
- updated_at
type: object
AvailabilityStatusRequest:
description: |
Models to query — each entry is `model_id → URL`. The URL lets
the server compute file_size + is_hf_downloadable on the same
request, eliminating the need for a separate metadata endpoint.
properties:
models:
additionalProperties:
type: string
description: model_id → URL declared in the workflow.
type: object
required:
- models
type: object
AvailabilityStatusResponse:
description: Per-model state + metadata + HF auth snapshot.
properties:
hf_auth:
$ref: '#/components/schemas/HfAuthStatus'
models:
additionalProperties:
$ref: '#/components/schemas/ModelStatusEntry'
type: object
required:
- models
- hf_auth
type: object
CancelDownloadSessionRequest:
description: Request to cancel an in-flight download for a given model_id.
properties:
model_id:
type: string
required:
- model_id
type: object
CancelDownloadSessionResponse:
description: Result of a cancellation request.
properties:
cancelled:
type: boolean
required:
- cancelled
type: object
CreateWorkflowRequest:
description: Request body for creating a new saved workflow.
properties:
@ -230,6 +273,51 @@ components:
- base_version
- workflow_json
type: object
DownloadModelsRequest:
description: Map of model_id → URL of files to fetch into the model folders.
properties:
models:
additionalProperties:
type: string
description: model_id → URL of models to download.
type: object
required:
- models
type: object
DownloadModelsResponse:
description: Acknowledgement that downloads have been scheduled.
properties:
accepted:
description: Always true; the request was scheduled.
type: boolean
scheduled:
description: The list of model_ids whose downloads are now in-flight.
items:
type: string
type: array
required:
- accepted
- scheduled
type: object
DownloadProgress:
description: In-flight download progress; embedded in ModelStatusEntry.
properties:
bytes_downloaded:
format: int64
type: integer
progress:
description: Fraction in [0,1]; null until total_bytes is known.
format: float
nullable: true
type: number
total_bytes:
description: Content-Length when supplied by the source.
format: int64
nullable: true
type: integer
required:
- bytes_downloaded
type: object
ErrorResponse:
description: Standard error response with a machine-readable code and human-readable message.
properties:
@ -394,6 +482,46 @@ components:
- name
- info
type: object
HfAuthLoginStartResponse:
description: URL the frontend should open in a new tab to complete login.
properties:
authorize_url:
type: string
required:
- authorize_url
type: object
HfAuthLogoutResponse:
description: Result of the logout call (always logged_out = true).
properties:
logged_out:
type: boolean
required:
- logged_out
type: object
HfAuthStatus:
description: Inline snapshot of the server's HuggingFace OAuth state.
properties:
eligible:
description: True iff this deployment can surface interactive HF login.
type: boolean
token_available:
description: True iff a token (possibly expired but refreshable) is stored.
type: boolean
required:
- token_available
- eligible
type: object
HfAuthTokenStatusResponse:
description: Whether the server holds an HF OAuth token + resolved username.
properties:
token_available:
type: boolean
username:
nullable: true
type: string
required:
- token_available
type: object
HistoryDetailEntry:
description: History entry with full prompt data
properties:
@ -798,6 +926,40 @@ components:
- name
- folders
type: object
ModelStatusEntry:
description: Everything the UI needs to render one row of the model.
properties:
file_size:
description: Bytes, when known. Cached server-side per URL.
format: int64
nullable: true
type: integer
is_hf_downloadable:
description: |
HuggingFace-only signal. True if the server can fetch this
URL with its current auth state (public, or gated-with-access).
False if gated and lacking access. Null for non-HF URLs and
for HF URLs whose probe failed entirely.
nullable: true
type: boolean
progress:
allOf:
- $ref: '#/components/schemas/DownloadProgress'
description: Present when `state == downloading`.
nullable: true
state:
description: |
`available` — file is on disk.
`missing` — not on disk and no download in flight.
`downloading` — server is currently fetching the file.
enum:
- available
- missing
- downloading
type: string
required:
- state
type: object
NodeInfo:
description: Metadata describing a single ComfyUI node type and its inputs/outputs.
properties:
@ -2338,6 +2500,60 @@ paths:
summary: Get tag histogram for filtered assets
tags:
- file
/api/cancel-model-download-session:
post:
description: |
Cancels the download session for the given model_id. The worker
observes the cancellation between chunks, removes its partial `.tmp`
file, and exits without writing the destination path.
operationId: postCancelModelDownloadSession
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/CancelDownloadSessionRequest'
required: true
responses:
"200":
content:
application/json:
schema:
$ref: '#/components/schemas/CancelDownloadSessionResponse'
description: Session cancelled.
"404":
description: No active download for that model_id.
summary: Cancel an in-flight server-side model download
tags:
- model
/api/download-models:
post:
description: |
Schedules downloads for every model_id in the request map. Returns
immediately after validation; progress is observed via
`/api/models-availability-status`. Fails atomically if any model
is already on disk, already downloading, gated, or has a URL that
is not on the server's allowlist (HuggingFace, Civitai, localhost).
operationId: postDownloadModels
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/DownloadModelsRequest'
required: true
responses:
"202":
content:
application/json:
schema:
$ref: '#/components/schemas/DownloadModelsResponse'
description: Downloads accepted and scheduled.
"400":
description: One of the requested models is invalid, gated, or has a non-allowed URL.
"409":
description: One of the requested models is already on disk or downloading.
summary: Start a server-side download of one or more models
tags:
- model
/api/embeddings:
get:
description: Returns the list of text-encoder embeddings available on disk.
@ -2639,6 +2855,66 @@ paths:
summary: Get a specific subgraph blueprint
tags:
- workflow
/api/hf-auth-login-start:
post:
description: |
Spawns a short-lived loopback callback server (port 41954) and
returns the URL the frontend should open in a new tab. After the
user grants consent, HF redirects back to the callback URL with
an authorization code; the server exchanges that for a token and
persists it. Subsequent `/api/hf-auth-token-status` calls will
return `token_available: true`. Rejected with 403 if the
deployment is not eligible (not loopback or in --multi-user mode);
409 if another login attempt is already in progress.
operationId: postHfAuthLoginStart
responses:
"200":
content:
application/json:
schema:
$ref: '#/components/schemas/HfAuthLoginStartResponse'
description: Login flow started; `authorize_url` is ready.
"403":
description: Deployment is not eligible for interactive HF login.
"409":
description: Another login attempt is already in progress.
summary: Begin a HuggingFace OAuth login flow
tags:
- model
/api/hf-auth-logout:
post:
description: Clears the in-memory cache and removes the on-disk token file.
operationId: postHfAuthLogout
responses:
"200":
content:
application/json:
schema:
$ref: '#/components/schemas/HfAuthLogoutResponse'
description: Logged out (idempotent — succeeds even if no token was held).
summary: Drop the stored HuggingFace OAuth token
tags:
- model
/api/hf-auth-token-status:
get:
description: |
Returns `token_available: true` when the server has a token
in memory (or on disk) for HuggingFace, irrespective of whether
the access_token is currently fresh — an expired one with a
refresh_token still counts as "logged in" because we'll refresh
transparently on next use. If a username is resolvable via
`HfApi.whoami` we return that too, for the settings UI.
operationId: getHfAuthTokenStatus
responses:
"200":
content:
application/json:
schema:
$ref: '#/components/schemas/HfAuthTokenStatusResponse'
description: Token status.
summary: Whether the server holds a usable HuggingFace OAuth token
tags:
- model
/api/history:
post:
deprecated: true
@ -3141,6 +3417,44 @@ paths:
summary: Cancel multiple jobs
tags:
- workflow
/api/models-availability-status:
post:
description: |
Given a map of `{model_id: url}` (model_id is
`<directory>/<filename>`), returns per-id state plus the
metadata the UI needs to render the row:
- `state` — one of `available` / `missing` / `downloading`
- `progress` — embedded when `state == downloading`
- `file_size` — bytes (when known)
- `is_hf_downloadable` — for HF URLs only: true if the
server can currently fetch the file with its stored auth
state, false if gated and lacking access, null otherwise
Designed for 1 Hz polling. `file_size` and the intrinsic
"is this model gated" check are cached server-side per URL;
`is_hf_downloadable` is recomputed per call so license
acceptance and login/logout transitions show up within one
poll interval without any client-side cache plumbing.
operationId: postModelsAvailabilityStatus
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/AvailabilityStatusRequest'
required: true
responses:
"200":
content:
application/json:
schema:
$ref: '#/components/schemas/AvailabilityStatusResponse'
description: Per-model status and metadata.
"400":
description: Malformed request body.
summary: Unified per-model status + metadata for the polling UI
tags:
- model
/api/node_replacements:
get:
description: |
@ -5087,3 +5401,5 @@ tags:
name: queue
- description: Job lifecycle queries
name: job
- description: Server-side model availability and downloads
name: model

View File

@ -9,6 +9,7 @@ numpy>=1.25.0
einops
transformers>=4.50.3
tokenizers>=0.13.3
huggingface_hub
sentencepiece
safetensors>=0.4.2
aiohttp>=3.11.8

View File

@ -47,6 +47,7 @@ from app.assets.seeder import asset_seeder
from app.assets.api.routes import register_assets_routes
from app.assets.services.ingest import register_file_in_place
from app.assets.services.asset_management import resolve_hash_to_path
from app.model_downloader.api.routes import register_routes as register_model_downloader_routes
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
@ -256,6 +257,7 @@ class PromptServer():
else:
register_assets_routes(self.app)
asset_seeder.disable()
register_model_downloader_routes(self.app)
routes = web.RouteTableDef()
self.routes = routes
self.last_node_id = None

View File

@ -0,0 +1,567 @@
"""Unit tests for the HuggingFace auth subsystem.
Covers:
- token store: save/load roundtrip, chmod 0600, atomic write, delete
- eligibility under various CLI-arg combinations
- URL parsing (huggingface.co host detection + repo_id extraction)
- HF-aware gated_detection.probe_url (mocked auth_check)
- HF auth routes (token status, login start with eligibility gate, logout)
- PKCE primitives + authorize URL shape
The OAuth callback server itself isn't exercised end-to-end here — that
requires a real HF server. We test the components (state checking,
URL building, code-exchange request shape) instead.
"""
from __future__ import annotations
import json
import os
import stat
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from aiohttp import web
from app.model_downloader.api.routes import register_routes
from app.model_downloader.hf_auth import oauth
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE, HfAuthStore
from app.model_downloader.hf_auth.token_store import (
EXPIRY_BUFFER_SECS,
Token,
delete_token,
load_token,
save_token,
)
from app.model_downloader.hf_url import is_hf_url, repo_id_from_url
pytestmark = pytest.mark.asyncio
# --------------------------------------------------------------------------- #
# Fixtures
# --------------------------------------------------------------------------- #
@pytest.fixture
def patched_user_dir(tmp_path):
    """Yield an isolated ``user`` directory created under ``tmp_path``.

    ``folder_paths.get_user_directory`` is patched to return that
    directory for as long as the fixture is active, so the token file is
    written there rather than into the real user directory.
    """
    isolated_dir = tmp_path / "user"
    isolated_dir.mkdir()
    redirect = patch(
        "folder_paths.get_user_directory",
        return_value=str(isolated_dir),
    )
    with redirect:
        yield isolated_dir
@pytest.fixture
def fresh_auth_store():
    """Yield the ``HF_AUTH_STORE`` singleton with clean state.

    The in-memory token, the loaded-from-disk flag and the
    ``gated_detection`` caches are reset before the test and again after
    it.
    """
    from app.model_downloader import gated_detection

    def _wipe():
        HF_AUTH_STORE._token = None
        HF_AUTH_STORE._loaded_from_disk = False
        gated_detection.clear_caches_for_tests()

    _wipe()
    yield HF_AUTH_STORE
    _wipe()
@pytest.fixture
def app(patched_user_dir, fresh_auth_store):
    """Return an aiohttp application with the model-downloader routes registered."""
    application = web.Application()
    register_routes(application)
    return application
# --------------------------------------------------------------------------- #
# URL parsing
# --------------------------------------------------------------------------- #
def test_is_hf_url_recognises_huggingface_co():
    """``huggingface.co`` URLs are recognised; mirror and Civitai URLs are not."""
    hf_urls = (
        "https://huggingface.co/x/y/resolve/main/z.safetensors",
        "https://huggingface.co/abc",
    )
    other_urls = (
        "https://hf-mirror.com/x/y/resolve/main/z.safetensors",
        "https://civitai.com/x.safetensors",
    )
    for candidate in hf_urls:
        assert is_hf_url(candidate)
    for candidate in other_urls:
        assert not is_hf_url(candidate)
def test_repo_id_from_url_extracts_org_and_repo():
    """A resolve URL yields its ``<org>/<repo>`` pair as the repo id."""
    repo_id = repo_id_from_url(
        "https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-HDR/resolve/main/x.safetensors"
    )
    assert repo_id == "Lightricks/LTX-2.3-22b-IC-LoRA-HDR"
def test_repo_id_from_url_handles_nested_path():
    """Sub-directories after ``resolve/<rev>/`` do not leak into the repo id."""
    repo_id = repo_id_from_url(
        "https://huggingface.co/Comfy-Org/ltx-2.3/resolve/main/split_files/loras/x.safetensors"
    )
    assert repo_id == "Comfy-Org/ltx-2.3"
def test_repo_id_from_url_returns_none_for_non_hf():
    """A Civitai URL has no HuggingFace repo id."""
    repo_id = repo_id_from_url("https://civitai.com/x.safetensors")
    assert repo_id is None
def test_repo_id_from_url_returns_none_for_non_resolve_paths():
    """HF URLs that are not ``resolve`` paths produce no repo id."""
    non_resolve_urls = (
        "https://huggingface.co/org/repo/blob/main/x.safetensors",
        "https://huggingface.co/org",
    )
    for candidate in non_resolve_urls:
        assert repo_id_from_url(candidate) is None
# --------------------------------------------------------------------------- #
# Token store
# --------------------------------------------------------------------------- #
def test_token_store_roundtrip(patched_user_dir):
    """A token written with ``save_token`` is read back equal by ``load_token``."""
    original = Token(
        access_token="hf_abc",
        refresh_token="rf_def",
        expires_at=9999999999.0,
        scope="openid profile",
    )
    save_token(original)
    assert load_token() == original
def test_token_store_writes_0600(patched_user_dir):
    """On POSIX the saved token file has permission bits ``0600``."""
    save_token(Token(access_token="x", refresh_token=None, expires_at=0.0))
    token_file = os.path.join(patched_user_dir, "hf_auth_token.json")
    permission_bits = stat.S_IMODE(os.stat(token_file).st_mode)
    # chmod is a silent no-op on Windows, so the exact mode is only
    # asserted on POSIX; elsewhere whatever the OS assigned is accepted.
    if os.name != "posix":
        return
    assert permission_bits == 0o600
def test_token_store_delete_removes_file(patched_user_dir):
    """``delete_token`` removes the token file and can be called again safely."""
    save_token(Token(access_token="x", refresh_token=None, expires_at=0.0))
    delete_token()
    token_file = os.path.join(patched_user_dir, "hf_auth_token.json")
    assert not os.path.exists(token_file)
    # With nothing left on disk, a repeated delete must not raise.
    delete_token()
def test_token_store_load_returns_none_for_missing_file(patched_user_dir):
    """With no token file in the user directory, ``load_token`` returns ``None``."""
    loaded = load_token()
    assert loaded is None
def test_token_store_load_returns_none_for_corrupt_file(patched_user_dir):
    """A token file that is not valid JSON loads as ``None``."""
    token_file = os.path.join(patched_user_dir, "hf_auth_token.json")
    with open(token_file, "w") as handle:
        handle.write("not json {")
    loaded = load_token()
    assert loaded is None
def test_token_is_valid_uses_buffer(patched_user_dir):
    """A token expiring inside ``EXPIRY_BUFFER_SECS`` is reported as not valid."""
    import time

    far_from_expiry = Token(
        access_token="x",
        refresh_token=None,
        expires_at=time.time() + 3600,
    )
    inside_buffer = Token(
        access_token="x",
        refresh_token=None,
        expires_at=time.time() + EXPIRY_BUFFER_SECS - 1,
    )
    assert far_from_expiry.is_valid()
    assert not inside_buffer.is_valid()
# --------------------------------------------------------------------------- #
# Auth store
# --------------------------------------------------------------------------- #
def test_auth_store_loads_lazily(patched_user_dir):
    """A store created after the token was saved to disk still sees that token."""
    on_disk = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
    save_token(on_disk)
    auth_store = HfAuthStore()
    assert auth_store.has_token()
    assert auth_store.get_token_sync() == on_disk
def test_auth_store_set_persists(patched_user_dir):
    """``set_token`` persists the token so a brand-new store can read it."""
    writer = HfAuthStore()
    persisted = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
    writer.set_token(persisted)
    # A second, independent store instance must observe the same token.
    reader = HfAuthStore()
    assert reader.get_token_sync() == persisted
def test_auth_store_clear_removes_in_memory_and_on_disk(patched_user_dir):
    """After ``clear`` neither the same store nor a new one reports a token."""
    auth_store = HfAuthStore()
    auth_store.set_token(
        Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
    )
    auth_store.clear()
    assert not auth_store.has_token()
    assert HfAuthStore().get_token_sync() is None
async def test_auth_store_get_valid_returns_fresh_token(patched_user_dir):
    """A token that expires an hour from now is returned unchanged."""
    import time

    auth_store = HfAuthStore()
    unexpired = Token(
        access_token="x",
        refresh_token=None,
        expires_at=time.time() + 3600,
    )
    auth_store.set_token(unexpired)
    assert await auth_store.get_valid_token() == unexpired
async def test_auth_store_get_valid_refresh_on_expired(patched_user_dir):
    """An expired token is replaced by what ``refresh_access_token`` returns."""
    import time

    auth_store = HfAuthStore()
    auth_store.set_token(
        Token(
            access_token="old",
            refresh_token="rf",
            expires_at=time.time() - 100,
        )
    )
    replacement = Token(
        access_token="new",
        refresh_token="rf",
        expires_at=time.time() + 3600,
    )
    refresh_stub = AsyncMock(return_value=replacement)
    with patch(
        "app.model_downloader.hf_auth.oauth.refresh_access_token",
        new=refresh_stub,
    ):
        outcome = await auth_store.get_valid_token()
    assert outcome == replacement
async def test_auth_store_get_valid_returns_none_on_refresh_failure(patched_user_dir):
    """If refreshing an expired token raises, ``get_valid_token`` returns ``None``."""
    import time

    auth_store = HfAuthStore()
    auth_store.set_token(
        Token(
            access_token="old",
            refresh_token="rf",
            expires_at=time.time() - 100,
        )
    )
    failing_refresh = AsyncMock(side_effect=RuntimeError("HF down"))
    with patch(
        "app.model_downloader.hf_auth.oauth.refresh_access_token",
        new=failing_refresh,
    ):
        outcome = await auth_store.get_valid_token()
    assert outcome is None
# --------------------------------------------------------------------------- #
# Eligibility
# --------------------------------------------------------------------------- #
@pytest.mark.parametrize(
    "listen,multi_user,expected",
    [
        ("127.0.0.1", False, True),
        ("127.0.0.1", True, False),  # multi-user disables it
        ("0.0.0.0", False, False),  # bind-all is not loopback
        ("0.0.0.0", True, False),
        ("192.168.1.5", False, False),  # LAN address
        ("::1", False, True),  # IPv6 loopback
    ],
)
def test_eligibility(listen, multi_user, expected, monkeypatch):
    """``is_hf_auth_eligible`` follows the ``listen`` / ``multi_user`` CLI args."""
    from app.model_downloader.hf_auth import eligibility
    from comfy.cli_args import args

    overrides = (("listen", listen), ("multi_user", multi_user))
    for attribute, value in overrides:
        monkeypatch.setattr(args, attribute, value)
    assert eligibility.is_hf_auth_eligible() is expected
# --------------------------------------------------------------------------- #
# gated_detection HF probe
# --------------------------------------------------------------------------- #
async def test_probe_url_hf_public(fresh_auth_store):
    """With a passing auth check and no token, the URL is reported downloadable."""
    from app.model_downloader.gated_detection import probe_url

    url = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
    auth_patch = patch("app.model_downloader.gated_detection._auth_check_sync")
    size_patch = patch(
        "app.model_downloader.gated_detection._probe_size_once",
        new=AsyncMock(return_value=1024),
    )
    with auth_patch, size_patch:
        probed = await probe_url(url)
    assert probed.is_hf_downloadable is True
    assert probed.file_size == 1024
async def test_probe_url_hf_gated_no_access(fresh_auth_store):
    """A ``GatedRepoError`` from the auth check gives ``is_hf_downloadable = False``."""
    from huggingface_hub.errors import GatedRepoError
    from app.model_downloader.gated_detection import probe_url

    url = "https://huggingface.co/gated/repo/resolve/main/x.safetensors"
    forbidden = MagicMock(status_code=403)
    auth_patch = patch(
        "app.model_downloader.gated_detection._auth_check_sync",
        side_effect=GatedRepoError("gated", response=forbidden),
    )
    size_patch = patch(
        "app.model_downloader.gated_detection._probe_size_once",
        new=AsyncMock(return_value=None),
    )
    with auth_patch, size_patch:
        probed = await probe_url(url)
    assert probed.is_hf_downloadable is False
async def test_probe_url_non_hf_skips_auth_check(fresh_auth_store):
    """Non-HF URLs never call auth_check; is_hf_downloadable stays None.

    ``fresh_auth_store`` is requested so the ``gated_detection`` caches
    are cleared before and after this test, as in every other
    ``probe_url`` test in this module. Without it, this test would start
    from whatever an earlier test left cached and would leave its own
    entries behind for later tests.
    """
    from app.model_downloader.gated_detection import probe_url

    url = "https://civitai.com/api/download/models/1.safetensors"
    with patch(
        "app.model_downloader.gated_detection._auth_check_sync",
    ) as mocked, patch(
        "app.model_downloader.gated_detection._probe_size_once",
        new=AsyncMock(return_value=2048),
    ):
        result = await probe_url(url)
    assert result.is_hf_downloadable is None
    assert result.file_size == 2048
    mocked.assert_not_called()
async def test_is_gated_cached_across_calls(fresh_auth_store):
    """Repeated ``probe_url`` calls for one URL run the auth check only once.

    The intrinsic ``is_gated`` answer is expected to be cached per URL,
    so the second and third probes must not issue the null-token auth
    check again.
    """
    from app.model_downloader.gated_detection import probe_url

    url = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
    auth_patch = patch("app.model_downloader.gated_detection._auth_check_sync")
    size_patch = patch(
        "app.model_downloader.gated_detection._probe_size_once",
        new=AsyncMock(return_value=1024),
    )
    with auth_patch as auth_check, size_patch:
        for _ in range(3):
            await probe_url(url)
    # Three probes of a public URL still add up to a single auth check.
    assert auth_check.call_count == 1
async def test_file_size_cached_across_calls(fresh_auth_store):
    """After one successful size probe, later calls reuse the cached size."""
    from app.model_downloader.gated_detection import probe_url

    url = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
    auth_patch = patch("app.model_downloader.gated_detection._auth_check_sync")
    size_patch = patch(
        "app.model_downloader.gated_detection._probe_size_once",
        new=AsyncMock(return_value=2048),
    )
    with auth_patch, size_patch as size_probe:
        first = await probe_url(url)
        second = await probe_url(url)
    assert first.file_size == 2048
    assert second.file_size == 2048
    assert size_probe.call_count == 1
async def test_file_size_not_probed_for_gated_no_access(fresh_auth_store):
    """When ``is_hf_downloadable`` is False we must NOT HEAD the URL —
    otherwise a 401-due-to-gating would land as a cached ``None`` that
    survives a later successful login."""
    from app.model_downloader.gated_detection import probe_url
    from huggingface_hub.errors import GatedRepoError

    # ``MagicMock`` comes from the module-level import; it does not need
    # to be imported again inside the test.
    url = "https://huggingface.co/gated/repo/resolve/main/x.safetensors"
    fake_resp = MagicMock(status_code=403)
    with patch(
        "app.model_downloader.gated_detection._auth_check_sync",
        side_effect=GatedRepoError("gated", response=fake_resp),
    ), patch(
        "app.model_downloader.gated_detection._probe_size_once",
        new=AsyncMock(return_value=None),
    ) as size_probe:
        result = await probe_url(url)
    assert result.is_hf_downloadable is False
    assert result.file_size is None
    # The size probe stub must never have been awaited.
    assert size_probe.call_count == 0
async def test_probe_url_passes_token_when_available(fresh_auth_store, patched_user_dir):
    """For a gated URL, auth_check runs twice: once with token=None to
    determine the intrinsic ``is_gated`` flag (cached forever), and once
    with the stored access_token to determine ``is_hf_downloadable`` for
    the current user."""
    from app.model_downloader.gated_detection import probe_url
    from huggingface_hub.errors import GatedRepoError

    # The ``fresh_auth_store`` fixture has already cleared the probe
    # caches, so no extra manual clear is needed here.
    fresh_auth_store.set_token(Token(
        access_token="hf_test_token",
        refresh_token=None,
        expires_at=9999999999.0,
    ))
    url = "https://huggingface.co/private/repo/resolve/main/x.safetensors"
    fake_resp = MagicMock(status_code=403)

    def fake_auth_check(repo_id, token):
        # Null-token call → repo is gated. Subsequent call with the real
        # token succeeds (user has access), signalled by returning normally.
        if token is None:
            raise GatedRepoError("gated", response=fake_resp)

    with patch(
        "app.model_downloader.gated_detection._auth_check_sync",
        side_effect=fake_auth_check,
    ) as mocked, patch(
        "app.model_downloader.gated_detection._probe_size_once",
        new=AsyncMock(return_value=None),
    ):
        result = await probe_url(url)
    # is_hf_downloadable should be True (token-authed call succeeded).
    assert result.is_hf_downloadable is True
    # Two calls: (repo_id, None) then (repo_id, <token>).
    assert mocked.call_count == 2
    assert mocked.call_args_list[0].args == ("private/repo", None)
    assert mocked.call_args_list[1].args == ("private/repo", "hf_test_token")
# --------------------------------------------------------------------------- #
# OAuth primitives
# --------------------------------------------------------------------------- #
def test_make_pkce_returns_distinct_high_entropy_values():
    """Two ``_make_pkce`` calls differ in verifier, challenge and state."""
    first_verifier, first_challenge, first_state = oauth._make_pkce()
    second_verifier, second_challenge, second_state = oauth._make_pkce()
    pairs = (
        (first_verifier, second_verifier),
        (first_challenge, second_challenge),
        (first_state, second_state),
    )
    for earlier, later in pairs:
        assert earlier != later
    # The PKCE spec requires a verifier of at least 43 characters.
    assert len(first_verifier) >= 43
def test_build_authorize_url_includes_pkce_and_state():
    """The authorize URL starts at ``AUTHORIZE_URL`` and carries the PKCE params."""
    authorize_url = oauth._build_authorize_url("challenge123", "state456")
    assert authorize_url.startswith(oauth.AUTHORIZE_URL)
    expected_fragments = (
        "client_id=" + oauth.HF_CLIENT_ID,
        "code_challenge=challenge123",
        "code_challenge_method=S256",
        "state=state456",
        "response_type=code",
    )
    for fragment in expected_fragments:
        assert fragment in authorize_url
# --------------------------------------------------------------------------- #
# Routes
# --------------------------------------------------------------------------- #
async def test_hf_auth_token_status_empty(aiohttp_client, app):
    """With no stored token the endpoint reports no token and a null username."""
    client = await aiohttp_client(app)
    response = await client.get("/api/hf-auth-token-status")
    assert response.status == 200
    payload = await response.json()
    assert payload == {"token_available": False, "username": None}
async def test_hf_auth_token_status_with_token(
    aiohttp_client, app, fresh_auth_store, patched_user_dir
):
    """With a stored token and a working whoami stub, the username is returned."""
    stored = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
    fresh_auth_store.set_token(stored)
    whoami_patch = patch(
        "app.model_downloader.api.routes._whoami_username",
        return_value="alice",
    )
    with whoami_patch:
        client = await aiohttp_client(app)
        response = await client.get("/api/hf-auth-token-status")
        assert response.status == 200
        payload = await response.json()
        assert payload == {"token_available": True, "username": "alice"}
async def test_hf_auth_login_start_403_when_ineligible(aiohttp_client, app, monkeypatch):
    """An ineligible deployment gets 403 with code ``HF_AUTH_NOT_ELIGIBLE``."""
    monkeypatch.setattr(
        "app.model_downloader.api.routes.is_hf_auth_eligible",
        lambda: False,
    )
    client = await aiohttp_client(app)
    response = await client.post("/api/hf-auth-login-start")
    assert response.status == 403
    payload = await response.json()
    assert payload["error"]["code"] == "HF_AUTH_NOT_ELIGIBLE"
async def test_hf_auth_login_start_returns_authorize_url(aiohttp_client, app, monkeypatch):
    """An eligible deployment gets 200 and the URL from ``start_login_flow``."""
    routes_module = "app.model_downloader.api.routes"
    login_stub = AsyncMock(
        return_value="https://huggingface.co/oauth/authorize?fake=1"
    )
    monkeypatch.setattr(routes_module + ".is_hf_auth_eligible", lambda: True)
    monkeypatch.setattr(routes_module + ".start_login_flow", login_stub)
    client = await aiohttp_client(app)
    response = await client.post("/api/hf-auth-login-start")
    assert response.status == 200
    payload = await response.json()
    assert payload["authorize_url"].startswith(
        "https://huggingface.co/oauth/authorize"
    )
async def test_hf_auth_login_start_409_when_in_progress(aiohttp_client, app, monkeypatch):
    """``OAuthInProgressError`` from the login flow maps to 409 ``HF_AUTH_IN_PROGRESS``."""
    from app.model_downloader.hf_auth.oauth import OAuthInProgressError

    routes_module = "app.model_downloader.api.routes"
    busy_login = AsyncMock(side_effect=OAuthInProgressError())
    monkeypatch.setattr(routes_module + ".is_hf_auth_eligible", lambda: True)
    monkeypatch.setattr(routes_module + ".start_login_flow", busy_login)
    client = await aiohttp_client(app)
    response = await client.post("/api/hf-auth-login-start")
    assert response.status == 409
    payload = await response.json()
    assert payload["error"]["code"] == "HF_AUTH_IN_PROGRESS"
async def test_hf_auth_logout_clears_store(
    aiohttp_client, app, fresh_auth_store, patched_user_dir
):
    """Logout returns ``logged_out: true`` and leaves the store without a token."""
    stored = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
    fresh_auth_store.set_token(stored)
    client = await aiohttp_client(app)
    response = await client.post("/api/hf-auth-logout")
    assert response.status == 200
    payload = await response.json()
    assert payload == {"logged_out": True}
    assert not fresh_auth_store.has_token()
async def test_availability_includes_hf_auth_snapshot(aiohttp_client, app, monkeypatch):
    """The availability response carries an ``hf_auth`` snapshot object."""
    monkeypatch.setattr(
        "app.model_downloader.api.routes.is_hf_auth_eligible",
        lambda: True,
    )
    client = await aiohttp_client(app)
    request_body = {"models": {}}
    response = await client.post(
        "/api/models-availability-status",
        json=request_body,
    )
    assert response.status == 200
    payload = await response.json()
    assert "hf_auth" in payload
    assert payload["hf_auth"] == {"token_available": False, "eligible": True}

View File

@ -0,0 +1,504 @@
"""Unit tests for the server-side model download subsystem.
Covers the pieces that don't require talking to a real network:
- path parsing & allowlist (pure functions)
- DownloadServer registry lifecycle (in-memory state)
- API routes via aiohttp_client + folder_paths/probe_url patches
Streaming downloads themselves are exercised only indirectly: the route-level
tests stub out the network probe so we can verify the gating logic in
``download_models`` without making real HTTP calls.
"""
from __future__ import annotations
import asyncio
import os
from unittest.mock import patch, AsyncMock
import pytest
from aiohttp import web
from app.model_downloader.allowlist import is_url_allowed
from app.model_downloader.api.routes import register_routes
from app.model_downloader.download_server import DownloadServer
from app.model_downloader.gated_detection import MetadataProbeResult
from app.model_downloader.paths import (
InvalidModelId,
parse_model_id,
resolve_destination,
resolve_existing,
)
# Global asyncio mark: the sync tests below trigger a cosmetic
# PytestWarning for each one because pytest-asyncio applies the mark
# indiscriminately. Other tests in this repo (see custom_node_manager_test.py)
# use the same pattern. The warnings are noise, not failures.
pytestmark = pytest.mark.asyncio
# --------------------------------------------------------------------------- #
# Fixtures
# --------------------------------------------------------------------------- #
@pytest.fixture
def model_root(tmp_path):
    """Create ``loras`` and ``checkpoints`` directories under ``tmp_path``.

    Returns the tuple ``(tmp_path, loras_dir, checkpoints_dir)``.
    """
    created = []
    for folder_name in ("loras", "checkpoints"):
        directory = tmp_path / folder_name
        directory.mkdir()
        created.append(directory)
    return (tmp_path, *created)
@pytest.fixture
def patched_folder_paths(model_root):
    """Patch ``folder_paths`` to the fake model roots and yield the mapping.

    Both ``folder_paths.folder_names_and_paths`` and
    ``folder_paths.get_folder_paths`` are replaced for the duration of
    the test; unknown folder names resolve to an empty path list.
    """
    _root, loras_dir, checkpoints_dir = model_root
    mapping = {
        "loras": ([str(loras_dir)], {".safetensors"}),
        "checkpoints": ([str(checkpoints_dir)], {".safetensors"}),
    }

    def _paths_for(name):
        entry = mapping.get(name, ([], set()))
        return entry[0]

    names_patch = patch("folder_paths.folder_names_and_paths", mapping)
    lookup_patch = patch("folder_paths.get_folder_paths", side_effect=_paths_for)
    with names_patch, lookup_patch:
        yield mapping
@pytest.fixture
def fresh_download_server():
    """Yield the ``DOWNLOAD_SERVER`` singleton, reset before and after the test.

    Resetting on both sides keeps registry state from leaking between
    tests that share the module-level instance.
    """
    from app.model_downloader.download_server import DOWNLOAD_SERVER as singleton

    singleton.reset_for_tests()
    yield singleton
    singleton.reset_for_tests()
@pytest.fixture
def app(patched_folder_paths, fresh_download_server):
    """Return an aiohttp application with the model-downloader routes registered."""
    application = web.Application()
    register_routes(application)
    return application
# --------------------------------------------------------------------------- #
# Pure helpers: allowlist + path parsing
# --------------------------------------------------------------------------- #
def test_allowlist_accepts_hf_safetensors():
    """A HuggingFace resolve URL ending in ``.safetensors`` is allowed."""
    allowed = is_url_allowed("https://huggingface.co/x/y/resolve/main/z.safetensors")
    assert allowed
def test_allowlist_accepts_civitai_pth():
    """A Civitai download URL ending in ``.pth`` is allowed."""
    allowed = is_url_allowed("https://civitai.com/api/download/models/123.pth")
    assert allowed
def test_allowlist_rejects_unknown_host():
    """A model file on a host outside the allowlist is rejected."""
    allowed = is_url_allowed("https://example.com/x.safetensors")
    assert not allowed
def test_allowlist_rejects_api_path_on_hf():
    """An allowlisted host is still rejected when the path is not a model file."""
    allowed = is_url_allowed("https://huggingface.co/api/models")
    assert not allowed
def test_allowlist_rejects_non_https_except_localhost():
    """Plain ``http`` is rejected for HuggingFace but accepted for localhost."""
    plain_http_hf = "http://huggingface.co/x/y.safetensors"
    plain_http_local = "http://localhost:8000/x.safetensors"
    assert not is_url_allowed(plain_http_hf)
    assert is_url_allowed(plain_http_local)
def test_parse_model_id_valid(patched_folder_paths):
    """``<folder>/<filename>`` is split into its folder and filename parts."""
    parsed = parse_model_id("loras/foo.safetensors")
    assert parsed == ("loras", "foo.safetensors")
def test_parse_model_id_rejects_traversal(patched_folder_paths):
    """A ``../`` style id raises ``InvalidModelId``."""
    traversal_id = "../etc/passwd"
    with pytest.raises(InvalidModelId):
        parse_model_id(traversal_id)
def test_parse_model_id_rejects_unknown_folder(patched_folder_paths):
    """An id whose folder is not registered raises ``InvalidModelId``."""
    unregistered_id = "nope/x.safetensors"
    with pytest.raises(InvalidModelId):
        parse_model_id(unregistered_id)
def test_parse_model_id_rejects_double_slash(patched_folder_paths):
    """An id containing a second path separator raises ``InvalidModelId``."""
    nested_id = "loras/sub/x.safetensors"
    with pytest.raises(InvalidModelId):
        parse_model_id(nested_id)
def test_resolve_existing_returns_path_when_present(model_root, patched_folder_paths):
    """A file that exists in the folder resolves to its full path."""
    _root, loras_dir, _ = model_root
    model_file = loras_dir / "foo.safetensors"
    model_file.write_bytes(b"x")
    resolved = resolve_existing("loras/foo.safetensors")
    assert resolved == str(model_file)
def test_resolve_existing_returns_none_when_absent(patched_folder_paths):
    """A file that is not on disk resolves to ``None``."""
    resolved = resolve_existing("loras/missing.safetensors")
    assert resolved is None
def test_resolve_destination_returns_tmp_pair(model_root, patched_folder_paths):
    """The destination pair is the final path and that path plus ``.tmp``."""
    _root, loras_dir, _ = model_root
    final_path, tmp_path_str = resolve_destination("loras/foo.safetensors")
    assert final_path == str(loras_dir / "foo.safetensors")
    assert tmp_path_str == final_path + ".tmp"
# --------------------------------------------------------------------------- #
# DownloadServer registry: lifecycle, races, cancellation epoch semantics
# --------------------------------------------------------------------------- #
def test_register_is_exclusive():
    """A second registration for an id that is already downloading returns ``None``."""
    server = DownloadServer()
    model_id = "loras/x.safetensors"
    first = server.try_register(model_id, "https://huggingface.co/a")
    second = server.try_register(model_id, "https://huggingface.co/b")
    assert first is not None
    assert second is None
    assert server.is_downloading(model_id)
def test_cancel_removes_session():
    """Cancelling a registered id returns ``True`` and ends the download state."""
    server = DownloadServer()
    model_id = "loras/x.safetensors"
    server.try_register(model_id, "https://huggingface.co/a")
    cancelled = server.cancel(model_id)
    assert cancelled is True
    assert not server.is_downloading(model_id)
def test_cancel_returns_false_when_absent():
    """Cancelling an id that was never registered returns ``False``."""
    server = DownloadServer()
    cancelled = server.cancel("loras/never.safetensors")
    assert cancelled is False
def test_finish_only_clears_matching_epoch():
    """``finish`` from a cancelled session must not evict its replacement.

    A session is registered, cancelled, and registered again for the
    same id. Finishing the old session leaves the id downloading; only
    finishing the new session clears it.
    """
    server = DownloadServer()
    model_id = "loras/x.safetensors"
    stale = server.try_register(model_id, "u1")
    server.cancel(model_id)
    current = server.try_register(model_id, "u2")
    assert current is not None and current.epoch != stale.epoch
    # The late finish() from the old worker has no effect.
    server.finish(stale)
    assert server.is_downloading(model_id)
    server.finish(current)
    assert not server.is_downloading(model_id)
def test_is_active_follows_cancellation():
    """A session is active after registration and inactive once cancelled."""
    server = DownloadServer()
    model_id = "loras/x.safetensors"
    session = server.try_register(model_id, "u")
    assert server.is_active(session)
    server.cancel(model_id)
    assert not server.is_active(session)
def test_update_progress_tracks_fraction():
    """50 of 100 bytes shows up in the snapshot as progress ``0.5``."""
    server = DownloadServer()
    model_id = "loras/x.safetensors"
    session = server.try_register(model_id, "u")
    server.update_progress(session, 50, 100)
    entry = server.snapshot()[model_id]
    assert entry.bytes_downloaded == 50
    assert entry.total_bytes == 100
    assert entry.progress == 0.5
def test_update_progress_with_unknown_total_keeps_progress_none():
    """When the total size is ``None`` the snapshot's progress stays ``None``."""
    server = DownloadServer()
    model_id = "loras/x.safetensors"
    session = server.try_register(model_id, "u")
    server.update_progress(session, 50, None)
    entry = server.snapshot()[model_id]
    assert entry.progress is None
def test_cleanup_orphan_tmp_files(model_root):
    """``sweep_orphan_tmp_files`` deletes a leftover ``.tmp`` file.

    Constructing the server must leave the file alone; the explicit
    sweep removes it, and a second sweep is harmless.
    """
    _root, loras_dir, _ = model_root
    leftover = loras_dir / "stale.safetensors.tmp"
    leftover.write_bytes(b"partial")
    mapping = {"loras": ([str(loras_dir)], {".safetensors"})}

    def _paths_for(name):
        return mapping.get(name, ([], set()))[0]

    names_patch = patch("folder_paths.folder_names_and_paths", mapping)
    lookup_patch = patch("folder_paths.get_folder_paths", side_effect=_paths_for)
    with names_patch, lookup_patch:
        server = DownloadServer()
        assert leftover.exists(), "sweep must not run at construction time"
        server.sweep_orphan_tmp_files()
        assert not leftover.exists()
        # Running the sweep again with nothing to remove must not fail.
        server.sweep_orphan_tmp_files()
# --------------------------------------------------------------------------- #
# Route: POST /api/models-availability-status
# --------------------------------------------------------------------------- #
async def test_availability_partitions_correctly(
    aiohttp_client, app, model_root, fresh_download_server
):
    """An on-disk, an absent and a registered model map to three distinct states."""
    inflight_url = "http://localhost:8000/x.safetensors"

    _, loras_dir, _ = model_root
    (loras_dir / "present.safetensors").write_bytes(b"x")
    fresh_download_server.try_register("loras/inflight.safetensors", inflight_url)
    client = await aiohttp_client(app)

    requested = {
        "loras/present.safetensors": "http://localhost:8000/p.safetensors",
        "loras/missing.safetensors": "http://localhost:8000/m.safetensors",
        "loras/inflight.safetensors": "http://localhost:8000/x.safetensors",
    }

    # The probe is stubbed out: this test is about state assignment only.
    blank_probe = AsyncMock(
        return_value=MetadataProbeResult(file_size=None, is_hf_downloadable=None)
    )
    with patch("app.model_downloader.api.routes.probe_url", new=blank_probe):
        resp = await client.post(
            "/api/models-availability-status", json={"models": requested}
        )
        assert resp.status == 200
        payload = await resp.json()

    expected_states = {
        "loras/present.safetensors": "available",
        "loras/missing.safetensors": "missing",
        "loras/inflight.safetensors": "downloading",
    }
    reported = payload["models"]
    for model_id, state in expected_states.items():
        assert reported[model_id]["state"] == state
    assert "hf_auth" in payload
async def test_availability_invalid_id_classified_as_missing(aiohttp_client, app):
    """A path-traversal style id still gets a 200 response and is reported as missing."""
    bad_id = "../etc/passwd"
    client = await aiohttp_client(app)

    blank_probe = AsyncMock(
        return_value=MetadataProbeResult(file_size=None, is_hf_downloadable=None)
    )
    with patch("app.model_downloader.api.routes.probe_url", new=blank_probe):
        request_body = {"models": {bad_id: "http://localhost:8000/x.safetensors"}}
        resp = await client.post(
            "/api/models-availability-status", json=request_body
        )
        assert resp.status == 200
        payload = await resp.json()

    assert payload["models"][bad_id]["state"] == "missing"
# --------------------------------------------------------------------------- #
# Route: POST /api/download-models — precondition gating
# --------------------------------------------------------------------------- #
async def test_download_rejects_url_not_in_allowlist(aiohttp_client, app):
    """A download request pointing at ``evil.com`` is refused with 400 / URL_NOT_ALLOWED."""
    client = await aiohttp_client(app)
    request_body = {
        "models": {"loras/x.safetensors": "https://evil.com/x.safetensors"}
    }

    resp = await client.post("/api/download-models", json=request_body)

    assert resp.status == 400
    payload = await resp.json()
    error = payload["error"]
    assert error["code"] == "URL_NOT_ALLOWED"
async def test_download_rejects_already_available(
    aiohttp_client, app, model_root
):
    """Requesting a model whose file already exists yields 409 / ALREADY_AVAILABLE."""
    _, loras_dir, _ = model_root
    existing_file = loras_dir / "x.safetensors"
    existing_file.write_bytes(b"x")
    client = await aiohttp_client(app)

    request_body = {
        "models": {
            "loras/x.safetensors": "https://huggingface.co/a/b/resolve/main/x.safetensors"
        }
    }
    resp = await client.post("/api/download-models", json=request_body)

    assert resp.status == 409
    payload = await resp.json()
    assert payload["error"]["code"] == "ALREADY_AVAILABLE"
async def test_download_rejects_already_downloading(
    aiohttp_client, app, fresh_download_server
):
    """Requesting an id that is already registered yields 409 / ALREADY_DOWNLOADING."""
    model_id = "loras/x.safetensors"
    # Pre-register the id so the route sees an existing session for it.
    fresh_download_server.try_register(
        model_id, "https://huggingface.co/u.safetensors"
    )
    client = await aiohttp_client(app)

    request_body = {
        "models": {
            model_id: "https://huggingface.co/a/b/resolve/main/x.safetensors"
        }
    }
    resp = await client.post("/api/download-models", json=request_body)

    assert resp.status == 409
    payload = await resp.json()
    assert payload["error"]["code"] == "ALREADY_DOWNLOADING"
async def test_download_rejects_gated_model(aiohttp_client, app):
    """A probe reporting ``is_hf_downloadable=False`` yields 400 / MODEL_NOT_DOWNLOADABLE."""
    client = await aiohttp_client(app)

    gated_result = MetadataProbeResult(file_size=None, is_hf_downloadable=False)
    gated_probe = AsyncMock(return_value=gated_result)
    with patch("app.model_downloader.api.routes.probe_url", new=gated_probe):
        request_body = {
            "models": {
                "loras/x.safetensors": "https://huggingface.co/g/r/resolve/main/x.safetensors"
            }
        }
        resp = await client.post("/api/download-models", json=request_body)
        assert resp.status == 400
        payload = await resp.json()

    assert payload["error"]["code"] == "MODEL_NOT_DOWNLOADABLE"
async def test_download_rejects_invalid_model_id(aiohttp_client, app):
    """A path-traversal style model id is refused with 400 / INVALID_MODEL_ID."""
    client = await aiohttp_client(app)
    request_body = {
        "models": {"../etc/passwd": "https://huggingface.co/x.safetensors"}
    }

    resp = await client.post("/api/download-models", json=request_body)

    assert resp.status == 400
    payload = await resp.json()
    assert payload["error"]["code"] == "INVALID_MODEL_ID"
async def test_download_atomic_failure_does_not_register_partial(
    aiohttp_client, app, model_root, fresh_download_server
):
    """A batch containing one rejected model must leave the other one unregistered."""
    _, loras_dir, _ = model_root
    (loras_dir / "already.safetensors").write_bytes(b"x")
    client = await aiohttp_client(app)

    batch = {
        "loras/already.safetensors":
            "https://huggingface.co/a/b/resolve/main/already.safetensors",
        "loras/new.safetensors":
            "https://huggingface.co/a/b/resolve/main/new.safetensors",
    }
    resp = await client.post("/api/download-models", json={"models": batch})

    assert resp.status == 409
    # The batch was rejected as a whole, so the otherwise-valid "new" entry
    # must not show up as an in-flight download.
    still_registered = fresh_download_server.is_downloading("loras/new.safetensors")
    assert not still_registered
async def test_download_schedules_when_all_preconditions_pass(
    aiohttp_client, app, fresh_download_server
):
    """Happy path: the route answers 202 and a background worker starts.

    The probe and the streamer are both patched, so no real HTTP request
    to the model host is made while the route itself runs end-to-end.
    """
    model_id = "loras/new.safetensors"
    worker_started = asyncio.Event()
    release_worker = asyncio.Event()

    async def fake_stream(session):
        # Signal that the worker is running, then block until the test
        # releases it so the "downloading" state can be observed.
        worker_started.set()
        await release_worker.wait()
        from app.model_downloader.download_server import DOWNLOAD_SERVER

        DOWNLOAD_SERVER.finish(session)
        return "/dev/null"

    downloadable = MetadataProbeResult(file_size=42, is_hf_downloadable=True)
    with patch(
        "app.model_downloader.api.routes.probe_url",
        new=AsyncMock(return_value=downloadable),
    ), patch("app.model_downloader.downloader.stream_to_disk", new=fake_stream):
        client = await aiohttp_client(app)
        request_body = {
            "models": {
                model_id: "https://huggingface.co/a/b/resolve/main/new.safetensors"
            }
        }
        resp = await client.post("/api/download-models", json=request_body)

        assert resp.status == 202
        payload = await resp.json()
        assert payload["accepted"] is True
        assert payload["scheduled"] == [model_id]

        # Give the scheduled worker up to two seconds to reach fake_stream.
        await asyncio.wait_for(worker_started.wait(), timeout=2.0)
        assert fresh_download_server.is_downloading(model_id)

        release_worker.set()
# --------------------------------------------------------------------------- #
# Route: POST /api/cancel-model-download-session
# --------------------------------------------------------------------------- #
async def test_cancel_removes_active_session(
    aiohttp_client, app, fresh_download_server
):
    """Cancelling a registered id returns 200 and the id stops being reported as downloading."""
    model_id = "loras/x.safetensors"
    fresh_download_server.try_register(
        model_id, "https://huggingface.co/u.safetensors"
    )
    client = await aiohttp_client(app)

    resp = await client.post(
        "/api/cancel-model-download-session", json={"model_id": model_id}
    )

    assert resp.status == 200
    payload = await resp.json()
    assert payload["cancelled"] is True
    assert not fresh_download_server.is_downloading(model_id)
async def test_cancel_returns_404_when_none(aiohttp_client, app):
    """Cancelling an id with no registered session yields 404 / NOT_DOWNLOADING."""
    client = await aiohttp_client(app)
    request_body = {"model_id": "loras/nothing.safetensors"}

    resp = await client.post(
        "/api/cancel-model-download-session", json=request_body
    )

    assert resp.status == 404
    payload = await resp.json()
    assert payload["error"]["code"] == "NOT_DOWNLOADING"
# --------------------------------------------------------------------------- #
# Unified availability response embeds metadata per id
# --------------------------------------------------------------------------- #
async def test_availability_embeds_metadata(aiohttp_client, app):
    """Each entry of the availability response carries its probe metadata.

    ``file_size`` and ``is_hf_downloadable`` are read from the same
    response that the availability request returns.
    """
    free_url = "https://huggingface.co/a/b/resolve/main/free.safetensors"
    gated_url = "https://huggingface.co/g/r/resolve/main/gated.safetensors"
    probe_by_url = {
        free_url: MetadataProbeResult(file_size=1024, is_hf_downloadable=True),
        gated_url: MetadataProbeResult(file_size=None, is_hf_downloadable=False),
    }

    async def fake_probe(url):
        # Per-URL canned result; an unexpected URL raises KeyError.
        return probe_by_url[url]

    with patch("app.model_downloader.api.routes.probe_url", new=fake_probe):
        client = await aiohttp_client(app)
        request_body = {
            "models": {
                "loras/free.safetensors": free_url,
                "loras/gated.safetensors": gated_url,
            }
        }
        resp = await client.post(
            "/api/models-availability-status", json=request_body
        )
        assert resp.status == 200
        payload = await resp.json()
        models = payload["models"]

    free_entry = models["loras/free.safetensors"]
    assert free_entry["file_size"] == 1024
    assert free_entry["is_hf_downloadable"] is True

    gated_entry = models["loras/gated.safetensors"]
    assert gated_entry["file_size"] is None
    assert gated_entry["is_hf_downloadable"] is False