mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 08:19:32 +08:00
Merge a337561802 into 6978a466b8
This commit is contained in:
commit
58c5cef538
46
app/model_downloader/allowlist.py
Normal file
46
app/model_downloader/allowlist.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""URL allowlist for server-side model fetches.
|
||||
|
||||
Mirrors the frontend's ``isModelDownloadable`` allowlist so the two flows
|
||||
agree on which URLs are eligible for download. Server-side allowlisting is
|
||||
the primary SSRF defense for this subsystem — workflow JSON is untrusted
|
||||
input (anyone can hand-craft one), so we never let the server fetch URLs
|
||||
outside this list.
|
||||
"""
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
# Frontend parity: ``missingModelDownload-*.js`` exports the same triple as
# ``i = [...]`` (Civitai / HuggingFace / localhost).
_ALLOWED_URL_PREFIXES = (
    "https://huggingface.co/",
    "https://civitai.com/",
    "http://localhost:",
    "http://127.0.0.1:",
)

# Hosts the prefixes above are meant to select. The prefix test alone is not
# enough for the loopback entries: they end inside the URL authority, so
# ``http://localhost:@evil.example/x.safetensors`` starts with an allowed
# prefix while actually pointing at ``evil.example``.
_ALLOWED_HOSTS = frozenset({
    "huggingface.co",
    "civitai.com",
    "localhost",
    "127.0.0.1",
})

# Frontend parity: same set as ``a = [...]`` in the bundle.
_ALLOWED_MODEL_EXTENSIONS = (
    ".safetensors",
    ".sft",
    ".ckpt",
    ".pth",
    ".pt",
)


def is_url_allowed(url: str) -> bool:
    """Check whether ``url`` is permitted as a server-side download source.

    Returns True only when all of the following hold:
      - the URL starts with one of the allowed prefixes,
      - the parsed authority has no ``user:password@`` part, a well-formed
        port, and a host that is one of the allowlisted hosts, AND
      - the URL's path ends with a known model extension.

    The prefix and extension checks keep arbitrary HTML / API endpoints on
    allowlisted hosts (e.g. ``https://huggingface.co/api/...``) off the table.
    The authority checks close the userinfo bypass of the loopback prefixes
    (``http://localhost:@other-host/...``).

    Any non-string, empty, or unparseable input yields False; this function
    never raises.
    """
    if not isinstance(url, str) or not url:
        return False
    if not url.startswith(_ALLOWED_URL_PREFIXES):
        return False
    try:
        parsed = urlparse(url)
        hostname = parsed.hostname
        # Accessing ``port`` validates it; a non-numeric or out-of-range port
        # raises ValueError and marks the authority as malformed.
        parsed.port
    except ValueError:
        return False
    # Userinfo is never needed for the allowlisted sources and is the classic
    # way to smuggle a different host behind an innocuous-looking prefix.
    if "@" in parsed.netloc:
        return False
    if hostname not in _ALLOWED_HOSTS:
        return False
    return parsed.path.endswith(_ALLOWED_MODEL_EXTENSIONS)
|
||||
332
app/model_downloader/api/routes.py
Normal file
332
app/model_downloader/api/routes.py
Normal file
@ -0,0 +1,332 @@
|
||||
"""Aiohttp routes for the server-side model download subsystem.
|
||||
|
||||
Endpoint surface (all under ``/api/``, all kebab-case):
|
||||
|
||||
- ``POST /api/models-availability-status`` — bulk status + metadata query.
|
||||
- ``POST /api/download-models`` — start a batch of downloads.
|
||||
- ``POST /api/cancel-model-download-session`` — cancel a single in-flight one.
|
||||
- ``GET /api/hf-auth-token-status`` — current HF login state.
|
||||
- ``POST /api/hf-auth-login-start`` — begin the HF OAuth flow.
|
||||
- ``POST /api/hf-auth-logout`` — drop the stored HF token.
|
||||
|
||||
The contract is intentionally narrow: only model_ids of the form
|
||||
``<directory>/<filename>`` (validated via ``app.model_downloader.paths``)
|
||||
are accepted, and only URLs on the same allowlist the frontend already
|
||||
uses (HuggingFace, Civitai, localhost) can be fetched. Both are required
|
||||
to keep the server out of the SSRF business for this feature.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from aiohttp import web
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from app.model_downloader.allowlist import is_url_allowed
|
||||
from app.model_downloader.download_server import (
|
||||
DOWNLOAD_SERVER,
|
||||
DownloadSession,
|
||||
)
|
||||
from app.model_downloader.downloader import schedule_batch
|
||||
from app.model_downloader.gated_detection import probe_url
|
||||
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
|
||||
from app.model_downloader.hf_auth.eligibility import is_hf_auth_eligible
|
||||
from app.model_downloader.hf_auth.oauth import (
|
||||
OAuthInProgressError,
|
||||
start_login_flow,
|
||||
)
|
||||
from app.model_downloader.paths import (
|
||||
InvalidModelId,
|
||||
parse_model_id,
|
||||
resolve_existing,
|
||||
)
|
||||
from app.model_downloader.api import schemas_in, schemas_out
|
||||
|
||||
# Route table that collects every handler declared in this module via the
# ``@ROUTES.get`` / ``@ROUTES.post`` decorators; ``register_routes`` attaches
# it to the aiohttp application.
ROUTES = web.RouteTableDef()
|
||||
|
||||
|
||||
def register_routes(app: web.Application) -> None:
    """Attach every model-downloader endpoint to ``app``.

    The handlers are collected in the module-level ``ROUTES`` table; this
    function hands that table to the aiohttp application. It is meant to be
    invoked a single time, from ``server.py`` while ``PromptServer`` starts.
    """
    app.add_routes(ROUTES)
|
||||
|
||||
|
||||
# ----- response helpers (same envelope as app/assets/api/routes.py) -----
|
||||
|
||||
|
||||
def _error(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
    """Build a JSON error response in the shared ``{"error": {...}}`` envelope.

    ``status`` is the HTTP status, ``code`` a machine-readable identifier,
    ``message`` the human-readable text. ``details`` carries optional extra
    context; a missing or empty value is sent as ``{}``.
    """
    envelope = {
        "error": {
            "code": code,
            "message": message,
            "details": details or {},
        },
    }
    return web.json_response(envelope, status=status)
|
||||
|
||||
|
||||
def _validation_error(code: str, ve: ValidationError) -> web.Response:
    """Turn a pydantic ``ValidationError`` into a 400 error response.

    The individual validation failures are round-tripped through
    ``ve.json()`` so they are plain JSON-serialisable data, and are placed
    under ``details.errors``.
    """
    failures = json.loads(ve.json())
    return _error(400, code, "Validation failed.", {"errors": failures})
|
||||
|
||||
|
||||
def _ok(payload: BaseModel, status: int = 200) -> web.Response:
    """Serialise a response model to JSON and wrap it in a ``web.Response``.

    Fields whose value is ``None`` are kept in the output (sent as ``null``)
    rather than dropped. ``status`` defaults to 200.
    """
    body = payload.model_dump(mode="json", exclude_none=False)
    return web.json_response(body, status=status)
|
||||
|
||||
|
||||
async def _parse_body(request: web.Request, model: type[BaseModel]) -> Any:
    """Parse the JSON request body into ``model``.

    Returns either a validated ``model`` instance or, on failure, a ready-made
    400 ``web.Response``. Nothing is raised for bad input, so callers must
    check ``isinstance(result, web.Response)`` and return it unchanged.

    Error codes:
      - ``INVALID_JSON``: the body is not decodable text or not valid JSON.
      - ``INVALID_BODY``: the JSON does not match the schema of ``model``.
    """
    try:
        raw = await request.json()
    except (json.JSONDecodeError, UnicodeDecodeError):
        # UnicodeDecodeError covers a body whose bytes do not match the
        # declared charset; without it such a request would end as a 500.
        return _error(400, "INVALID_JSON", "Request body must be valid JSON.")
    try:
        return model.model_validate(raw)
    except ValidationError as ve:
        return _validation_error("INVALID_BODY", ve)
|
||||
|
||||
|
||||
# ----- 1. availability status (unified: state + metadata per id) -----
|
||||
|
||||
|
||||
@ROUTES.post("/api/models-availability-status")
async def models_availability_status(request: web.Request) -> web.Response:
    """Return per-id ``{state, progress, file_size, is_hf_downloadable}``.

    The request body maps ``model_id -> url``. For every entry the response
    reports:
      - ``state``: ``downloading`` when the registry has an active session
        for the id, ``available`` when ``resolve_existing`` finds the file,
        otherwise ``missing``. An ill-formed ``model_id`` is reported as
        ``missing`` instead of failing the whole batch.
      - ``progress``: only present while ``downloading``.
      - ``file_size`` / ``is_hf_downloadable``: taken from ``probe_url``.

    Only URLs accepted by ``is_url_allowed`` are probed. The request body is
    untrusted, and probing is a server-side network fetch, so URLs outside
    the allowlist get ``file_size=None`` and ``is_hf_downloadable=None``
    without any request being made. ``download_models`` rejects such URLs
    anyway, so no downloadable model loses metadata.

    The response also embeds the current HuggingFace auth snapshot
    (``token_available`` / ``eligible``).
    """
    parsed = await _parse_body(request, schemas_in.AvailabilityStatusRequest)
    if isinstance(parsed, web.Response):
        return parsed

    items = list(parsed.models.items())

    async def _probe_if_allowed(url: str):
        # Defence in depth: never let this endpoint trigger a fetch for a
        # URL the download endpoint itself would refuse.
        if not is_url_allowed(url):
            return None
        return await probe_url(url)

    # Run all probes concurrently.
    probes = await asyncio.gather(*(_probe_if_allowed(url) for _, url in items))

    response_models: dict[str, schemas_out.ModelStatusEntry] = {}
    for (model_id, _url), probe in zip(items, probes):
        file_size = probe.file_size if probe is not None else None
        is_hf_downloadable = probe.is_hf_downloadable if probe is not None else None

        state: schemas_out.ModelState = "missing"
        progress: Optional[schemas_out.DownloadProgress] = None
        try:
            parse_model_id(model_id)
        except InvalidModelId:
            # Ill-formed identifier: report as missing without 400-ing the
            # whole batch — the workflow author probably typo'd.
            pass
        else:
            active = DOWNLOAD_SERVER.get(model_id)
            if active is not None:
                state = "downloading"
                progress = schemas_out.DownloadProgress(
                    bytes_downloaded=active.bytes_downloaded,
                    total_bytes=active.total_bytes,
                    progress=active.progress,
                )
            elif resolve_existing(model_id) is not None:
                state = "available"

        response_models[model_id] = schemas_out.ModelStatusEntry(
            state=state,
            progress=progress,
            file_size=file_size,
            is_hf_downloadable=is_hf_downloadable,
        )

    return _ok(schemas_out.AvailabilityStatusResponse(
        models=response_models,
        hf_auth=schemas_out.HfAuthStatus(
            token_available=HF_AUTH_STORE.has_token(),
            eligible=is_hf_auth_eligible(),
        ),
    ))
|
||||
|
||||
|
||||
# ----- 2. start downloads -----
|
||||
|
||||
|
||||
@ROUTES.post("/api/download-models")
async def download_models(request: web.Request) -> web.Response:
    """Validate and schedule a batch of server-side model downloads.

    The body maps ``model_id -> url``. The request is all-or-nothing: every
    entry must pass every precondition before any session is registered.

    Responses:
      - 202 with ``{accepted: true, scheduled: [...]}`` once the batch has
        been handed to ``schedule_batch`` (the downloads run afterwards).
      - 400 ``EMPTY_REQUEST`` / ``INVALID_MODEL_ID`` / ``URL_NOT_ALLOWED`` /
        ``MODEL_NOT_DOWNLOADABLE`` for invalid or inaccessible input.
      - 409 ``ALREADY_AVAILABLE`` / ``ALREADY_DOWNLOADING`` on conflicts
        with files on disk or in-flight sessions.
      - Whatever ``_parse_body`` returns for malformed bodies.
    """
    parsed = await _parse_body(request, schemas_in.DownloadModelsRequest)
    if isinstance(parsed, web.Response):
        return parsed

    if not parsed.models:
        return _error(400, "EMPTY_REQUEST", "No models supplied.")

    # ----- precondition pass: validate everything BEFORE registering anything -----
    # Atomic semantics: if any model fails any precondition (invalid id,
    # not allow-listed URL, already on disk, already downloading, or gated),
    # the entire request fails and no state is changed.
    requested = list(parsed.models.items())

    for model_id, url in requested:
        try:
            parse_model_id(model_id)
        except InvalidModelId as e:
            return _error(400, "INVALID_MODEL_ID", str(e),
                          {"model_id": model_id})

        # The allowlist check runs before anything touches the network.
        if not is_url_allowed(url):
            return _error(
                400, "URL_NOT_ALLOWED",
                "Server-side downloads only accept HuggingFace, Civitai, "
                "or localhost URLs ending in a known model extension.",
                {"model_id": model_id, "url": url},
            )

        if resolve_existing(model_id) is not None:
            return _error(409, "ALREADY_AVAILABLE",
                          f"Model already exists on disk: {model_id}",
                          {"model_id": model_id})

        if DOWNLOAD_SERVER.is_downloading(model_id):
            return _error(409, "ALREADY_DOWNLOADING",
                          f"A download for {model_id} is already in progress.",
                          {"model_id": model_id})

    # Reachability check last — it's the only one that talks to the
    # network. Concurrent probes. For HF URLs ``is_hf_downloadable``
    # reflects current token access; for non-HF URLs it's None, and we
    # treat that as "no info, proceed".
    probes = await asyncio.gather(*(probe_url(url) for _, url in requested))
    for (model_id, url), probe in zip(requested, probes):
        # Only an explicit False blocks the request; None means "unknown".
        if probe.is_hf_downloadable is False:
            return _error(
                400, "MODEL_NOT_DOWNLOADABLE",
                f"Model {model_id} is gated on HuggingFace and the current "
                f"server token (if any) does not grant access.",
                {"model_id": model_id, "url": url},
            )

    # ----- registration pass: try_register is atomic per model_id -----
    # Defensive: another request might have raced past our pre-check
    # between the loop above and here. try_register handles that.
    sessions: list[DownloadSession] = []
    for model_id, url in requested:
        session = DOWNLOAD_SERVER.try_register(model_id, url)
        if session is None:
            # Race: someone else got in. Roll back what we registered.
            for s in sessions:
                DOWNLOAD_SERVER.cancel(s.model_id)
            return _error(409, "ALREADY_DOWNLOADING",
                          f"A download for {model_id} is already in progress (race).",
                          {"model_id": model_id})
        sessions.append(session)

    # The sweep is a no-op after its first run in this process; it happens
    # here, before the batch is scheduled.
    DOWNLOAD_SERVER.sweep_orphan_tmp_files()
    schedule_batch(sessions)
    logging.info(
        "[model_downloader] scheduled %d downloads: %s",
        len(sessions), [s.model_id for s in sessions],
    )

    # 202: the work has been accepted but is not finished yet.
    return _ok(schemas_out.DownloadModelsResponse(
        accepted=True,
        scheduled=[s.model_id for s in sessions],
    ), status=202)
|
||||
|
||||
|
||||
# ----- 3. cancel a session -----
|
||||
|
||||
|
||||
@ROUTES.post("/api/cancel-model-download-session")
async def cancel_model_download_session(request: web.Request) -> web.Response:
    """Cancel the in-flight download registered for the given ``model_id``.

    Responds with ``{cancelled: true}`` when a session was removed from the
    registry, or 404 ``NOT_DOWNLOADING`` when there was none. Malformed
    bodies get the 400 produced by ``_parse_body``.
    """
    body = await _parse_body(request, schemas_in.CancelDownloadSessionRequest)
    if isinstance(body, web.Response):
        return body

    model_id = body.model_id
    if DOWNLOAD_SERVER.cancel(model_id):
        return _ok(schemas_out.CancelDownloadSessionResponse(cancelled=True))

    return _error(
        404,
        "NOT_DOWNLOADING",
        f"No active download for {model_id}.",
        {"model_id": model_id},
    )
|
||||
|
||||
|
||||
# ----- 4. HuggingFace OAuth status / login start / logout -----
|
||||
|
||||
|
||||
@ROUTES.get("/api/hf-auth-token-status")
async def hf_auth_token_status(request: web.Request) -> web.Response:
    """Report whether the server holds an HF token and, if so, its username.

    ``token_available`` mirrors ``HF_AUTH_STORE.has_token()``. The username
    lookup is best-effort: when no valid token can be obtained, or the
    ``whoami`` call fails, ``username`` is simply ``null`` and the failure is
    logged at debug level.
    """
    has_token = HF_AUTH_STORE.has_token()
    resolved_name: Optional[str] = None

    token = await HF_AUTH_STORE.get_valid_token() if has_token else None
    if token is not None:
        # whoami is a synchronous, network-bound call, so it is pushed to a
        # worker thread instead of blocking the event loop.
        try:
            resolved_name = await asyncio.to_thread(
                _whoami_username, token.access_token
            )
        except Exception as e:
            logging.debug("[hf_auth] whoami failed: %s", e)

    status = schemas_out.HfAuthTokenStatusResponse(
        token_available=has_token,
        username=resolved_name,
    )
    return _ok(status)
|
||||
|
||||
|
||||
def _whoami_username(token: str) -> Optional[str]:
    """Blocking helper: look up the account name that ``token`` belongs to.

    Returns the ``name`` field of the ``whoami`` result, falling back to
    ``fullname``; ``None`` when the result is not a dict or has neither.
    """
    from huggingface_hub import HfApi

    account = HfApi().whoami(token=token)
    if not isinstance(account, dict):
        return None
    return account.get("name") or account.get("fullname")
|
||||
|
||||
|
||||
@ROUTES.post("/api/hf-auth-login-start")
async def hf_auth_login_start(request: web.Request) -> web.Response:
    """Start a HuggingFace OAuth attempt and return the authorize URL.

    Responses:
      - 200 ``{authorize_url}`` when ``start_login_flow`` succeeds.
      - 403 ``HF_AUTH_NOT_ELIGIBLE`` when ``is_hf_auth_eligible()`` is false.
      - 409 ``HF_AUTH_IN_PROGRESS`` when another attempt is already running.
    """
    if is_hf_auth_eligible():
        try:
            authorize_url = await start_login_flow()
        except OAuthInProgressError:
            return _error(
                409,
                "HF_AUTH_IN_PROGRESS",
                "Another HuggingFace login attempt is in progress. Try again "
                "after it completes or times out.",
            )
        return _ok(schemas_out.HfAuthLoginStartResponse(authorize_url=authorize_url))

    return _error(
        403,
        "HF_AUTH_NOT_ELIGIBLE",
        "This server is not eligible for interactive HuggingFace login. "
        "It must be bound to a loopback address and not running in "
        "--multi-user mode.",
    )
|
||||
|
||||
|
||||
@ROUTES.post("/api/hf-auth-logout")
async def hf_auth_logout(request: web.Request) -> web.Response:
    """Clear the stored HuggingFace token and confirm with ``logged_out``.

    Always answers ``{logged_out: true}``; the request body is not read.
    """
    HF_AUTH_STORE.clear()
    confirmation = schemas_out.HfAuthLogoutResponse(logged_out=True)
    return _ok(confirmation)
|
||||
41
app/model_downloader/api/schemas_in.py
Normal file
41
app/model_downloader/api/schemas_in.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""Request schemas for the model-downloader API.
|
||||
|
||||
Each endpoint accepts a small JSON body. Pydantic enforces the shape at
|
||||
the boundary; route handlers operate only on validated values past that.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AvailabilityStatusRequest(BaseModel):
    """``POST /api/models-availability-status``.

    Sent by the frontend on each poll. Each entry is ``{model_id: url}``;
    the URL is the one declared in ``properties.models[i].url`` in the
    workflow JSON and lets the server compute per-id metadata
    (``file_size`` + ``is_hf_downloadable``) on the same request.
    """
    # model_id -> source URL. An omitted key validates as an empty mapping.
    models: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class DownloadModelsRequest(BaseModel):
    """``POST /api/download-models``.

    Same shape as the metadata request — the URL for each model_id.
    Returns immediately after validation and scheduling.
    """
    # model_id -> source URL. An omitted key validates as an empty mapping,
    # which the route handler rejects with ``EMPTY_REQUEST``.
    models: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class CancelDownloadSessionRequest(BaseModel):
    """``POST /api/cancel-model-download-session``."""
    # Identifier of the download to cancel; required.
    model_id: str
|
||||
|
||||
|
||||
# Public request schemas exported by this module.
__all__ = [
    "AvailabilityStatusRequest",
    "DownloadModelsRequest",
    "CancelDownloadSessionRequest",
]
|
||||
81
app/model_downloader/api/schemas_out.py
Normal file
81
app/model_downloader/api/schemas_out.py
Normal file
@ -0,0 +1,81 @@
|
||||
"""Response schemas for the model-downloader API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# Closed set of states a model can be reported in by the availability endpoint.
ModelState = Literal["available", "missing", "downloading"]
|
||||
|
||||
|
||||
class DownloadProgress(BaseModel):
    """Embedded in a model entry when its state is ``downloading``."""
    # Number of bytes received so far.
    bytes_downloaded: int
    # Expected size of the download; null while unknown.
    total_bytes: Optional[int] = None
    progress: Optional[float] = None  # fraction in [0,1]; null until total known
|
||||
|
||||
|
||||
class ModelStatusEntry(BaseModel):
    """Everything the UI needs to render one row, in one shot.

    ``state`` reflects what the server has on disk + in-flight; ``file_size``
    and ``is_hf_downloadable`` come from probes.
    Both probe-derived fields are included in every poll response, so
    license-acceptance flips can show up within one poll interval without
    any frontend cache invalidation.

    NOTE(review): the availability route's docstring says
    ``is_hf_downloadable`` is recomputed on every call, while earlier wording
    here described the HF fields as cached on the server — confirm which one
    matches ``probe_url``.
    """
    state: ModelState
    # Only set while ``state`` is ``downloading``.
    progress: Optional[DownloadProgress] = None
    # Size reported by the probe; null when unknown.
    file_size: Optional[int] = None
    # HF-only: True iff the server can fetch this URL with current auth
    # state. False iff gated and lacking access. None for non-HF URLs.
    is_hf_downloadable: Optional[bool] = None
|
||||
|
||||
|
||||
class HfAuthStatus(BaseModel):
    """Snapshot of HF login state, embedded in availability response."""
    # Whether the server currently holds an HF token.
    token_available: bool
    # Whether this deployment may offer the interactive HF login at all.
    eligible: bool
|
||||
|
||||
|
||||
class AvailabilityStatusResponse(BaseModel):
    """Body of the ``POST /api/models-availability-status`` response."""
    # One status entry per requested model_id.
    models: dict[str, ModelStatusEntry]
    # HF login snapshot taken while building the response.
    hf_auth: HfAuthStatus
|
||||
|
||||
|
||||
class DownloadModelsResponse(BaseModel):
    """Body of the 202 response from ``POST /api/download-models``."""
    # True when the batch was accepted for scheduling.
    accepted: bool
    # model_ids of the sessions that were scheduled.
    scheduled: list[str]
|
||||
|
||||
|
||||
class CancelDownloadSessionResponse(BaseModel):
    """Body of a successful ``POST /api/cancel-model-download-session``."""
    # True when an active session was removed.
    cancelled: bool
|
||||
|
||||
|
||||
class HfAuthTokenStatusResponse(BaseModel):
    """Body of the ``GET /api/hf-auth-token-status`` response."""
    # Whether the server currently holds an HF token.
    token_available: bool
    # Account name resolved for the token; null when it could not be resolved.
    username: Optional[str] = None
|
||||
|
||||
|
||||
class HfAuthLoginStartResponse(BaseModel):
    """Body of a successful ``POST /api/hf-auth-login-start``."""
    # URL returned by the login flow for the user to open.
    authorize_url: str
|
||||
|
||||
|
||||
class HfAuthLogoutResponse(BaseModel):
    """Body of the ``POST /api/hf-auth-logout`` response."""
    # True once the stored token has been cleared.
    logged_out: bool
|
||||
|
||||
|
||||
# Public response schemas (and the ``ModelState`` alias) exported by this module.
__all__ = [
    "ModelState",
    "DownloadProgress",
    "ModelStatusEntry",
    "HfAuthStatus",
    "AvailabilityStatusResponse",
    "DownloadModelsResponse",
    "CancelDownloadSessionResponse",
    "HfAuthTokenStatusResponse",
    "HfAuthLoginStartResponse",
    "HfAuthLogoutResponse",
]
|
||||
179
app/model_downloader/download_server.py
Normal file
179
app/model_downloader/download_server.py
Normal file
@ -0,0 +1,179 @@
|
||||
"""Process-wide registry of in-flight model downloads.
|
||||
|
||||
A single instance, ``DOWNLOAD_SERVER``, tracks every currently-running
|
||||
server-side model fetch. Designed to be safe with multiple concurrent
|
||||
clients hitting the API: each model_id has at most one active session,
|
||||
and the API rejects requests that conflict with in-flight downloads.
|
||||
|
||||
Cancellation is cooperative — the download loop checks ``is_active`` on
|
||||
its own session between chunks and raises ``DownloadCancelled`` when the
|
||||
session has been removed from the registry. This avoids the complications
|
||||
of ``Task.cancel()`` from outside the loop while still giving deterministic
|
||||
rollback semantics (the worker is responsible for deleting its own
|
||||
``.tmp`` on the cancel path).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from app.model_downloader.paths import iter_all_tmp_paths
|
||||
|
||||
|
||||
class DownloadCancelled(Exception):
    """Signal that a download's session is no longer in the registry.

    The streaming loop raises this once it notices the cancellation; the
    worker reacts by rolling back its ``.tmp`` file.
    """
|
||||
|
||||
|
||||
@dataclass
class DownloadSession:
    """One in-flight download.

    ``progress`` is a fraction in ``[0.0, 1.0]``; it stays ``None`` until a
    total size is known (the registry only computes it when ``total_bytes``
    is a positive number). ``total_bytes`` mirrors the response's
    ``Content-Length`` when present.
    """
    # Registry key, in the ``<directory>/<filename>`` form used by the API.
    model_id: str
    # Source URL the worker fetches.
    url: str
    progress: Optional[float] = None
    bytes_downloaded: int = 0
    total_bytes: Optional[int] = None
    # Sequence number used solely as identity for the cancellation check —
    # so that "cancel + restart" doesn't get confused by stale workers.
    # A plain default suffices: ints are immutable, no factory is needed.
    epoch: int = 0
|
||||
|
||||
|
||||
class DownloadServer:
    """Registry of the downloads that are currently running in this process.

    Every read and write of the session map happens while holding
    ``_lock``, a ``threading.Lock``, so route handlers and download workers
    observe a consistent picture. At most one session exists per
    ``model_id``; sessions are told apart by their ``epoch`` so that a stale
    worker can never touch a newer session registered under the same id.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._sessions: dict[str, DownloadSession] = {}
        self._epoch_counter = 0
        self._orphan_sweep_done = False

    # ----- lifecycle -----

    def sweep_orphan_tmp_files(self) -> None:
        """Delete leftover ``*.tmp`` files, at most once per instance.

        The first caller flips the "done" flag under the lock and then
        performs the sweep outside of it; every later call returns
        immediately. Files that cannot be removed are logged and skipped.
        """
        with self._lock:
            already_swept = self._orphan_sweep_done
            self._orphan_sweep_done = True
        if already_swept:
            return

        for leftover in iter_all_tmp_paths():
            try:
                os.remove(leftover)
            except OSError as exc:
                logging.warning("[model_downloader] could not remove %s: %s", leftover, exc)
            else:
                logging.info("[model_downloader] removed orphan tmp file: %s", leftover)

    # ----- queries -----

    def is_downloading(self, model_id: str) -> bool:
        """Tell whether a session is registered for ``model_id``."""
        with self._lock:
            return model_id in self._sessions

    def get(self, model_id: str) -> Optional[DownloadSession]:
        """Return the registered session for ``model_id``, or ``None``."""
        with self._lock:
            return self._sessions.get(model_id)

    def snapshot(self) -> dict[str, DownloadSession]:
        """Copy the session map (the session objects themselves are shared)."""
        with self._lock:
            return dict(self._sessions)

    # ----- mutations -----

    def _registered_twin(self, session: DownloadSession) -> Optional[DownloadSession]:
        """Return the registered session matching ``session``'s id and epoch.

        Must be called with ``_lock`` held. Yields ``None`` when nothing is
        registered under the id or when the registered session is a
        different generation.
        """
        candidate = self._sessions.get(session.model_id)
        if candidate is None or candidate.epoch != session.epoch:
            return None
        return candidate

    def try_register(self, model_id: str, url: str) -> Optional[DownloadSession]:
        """Register a fresh session unless ``model_id`` already has one.

        Check and insert happen under one lock acquisition. On success the
        new session (stamped with the next epoch) is returned; ``None``
        means a session was already registered, so callers must inspect
        the result.
        """
        with self._lock:
            if model_id not in self._sessions:
                self._epoch_counter += 1
                created = DownloadSession(
                    model_id=model_id,
                    url=url,
                    epoch=self._epoch_counter,
                )
                self._sessions[model_id] = created
                return created
        return None

    def update_progress(
        self,
        session: DownloadSession,
        bytes_downloaded: int,
        total_bytes: Optional[int],
    ) -> None:
        """Record byte counters on the registered session.

        Does nothing when ``session`` is no longer the registered one.
        ``progress`` is only recomputed when ``total_bytes`` is a positive
        number, and is capped at ``1.0``.
        """
        with self._lock:
            live = self._registered_twin(session)
            if live is None:
                return
            live.bytes_downloaded = bytes_downloaded
            live.total_bytes = total_bytes
            if total_bytes and total_bytes > 0:
                live.progress = min(1.0, bytes_downloaded / total_bytes)

    def is_active(self, session: DownloadSession) -> bool:
        """Tell whether ``session`` is still the one registered for its id.

        False once the session was cancelled or finished, or when a newer
        session took over the same ``model_id``.
        """
        with self._lock:
            return self._registered_twin(session) is not None

    def finish(self, session: DownloadSession) -> None:
        """Drop ``session`` from the registry if it is still registered.

        Repeated calls are harmless, and a newer session registered under
        the same ``model_id`` is left untouched because its epoch differs.
        """
        with self._lock:
            if self._registered_twin(session) is not None:
                del self._sessions[session.model_id]

    def reset_for_tests(self) -> None:
        """Forget every session and restart epoch numbering. Test-only."""
        with self._lock:
            self._sessions.clear()
            self._epoch_counter = 0

    def cancel(self, model_id: str) -> bool:
        """Unregister whatever session exists for ``model_id``.

        Returns True when a session was removed, False when there was none.
        The worker notices on its next ``is_active`` check and is expected
        to clean up its ``.tmp`` file.
        """
        with self._lock:
            return self._sessions.pop(model_id, None) is not None
|
||||
|
||||
|
||||
# The single process-wide registry instance; created at import time.
DOWNLOAD_SERVER = DownloadServer()
|
||||
205
app/model_downloader/downloader.py
Normal file
205
app/model_downloader/downloader.py
Normal file
@ -0,0 +1,205 @@
|
||||
"""Streaming download worker with progress reporting and cancellation.
|
||||
|
||||
Each download writes to ``<final_path>.tmp`` and atomically renames into
|
||||
place on success. Between chunks the worker checks the registry for
|
||||
cancellation (via ``DownloadServer.is_active``) and rolls back its
|
||||
``.tmp`` on cancel or on any error.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from app.model_downloader.download_server import (
|
||||
DOWNLOAD_SERVER,
|
||||
DownloadCancelled,
|
||||
DownloadSession,
|
||||
)
|
||||
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
|
||||
from app.model_downloader.hf_url import is_hf_url
|
||||
from app.model_downloader.http_client import get_session, parse_content_length
|
||||
from app.model_downloader.paths import resolve_destination
|
||||
|
||||
|
||||
# Size of the pieces read from the response stream; also the granularity of
# the cancellation check and of progress updates.
CHUNK_SIZE = 64 * 1024  # 64 KiB — same scale as other ComfyUI download paths.
# No overall deadline (``total=None``), since the whole transfer has no fixed
# upper bound here. Only connecting (30 s) and each individual socket read
# (120 s) are time-limited.
REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=120)
|
||||
|
||||
|
||||
async def stream_to_disk(session: DownloadSession) -> str:
    """Run a single download to completion or cancellation.

    Streams ``session.url`` into ``<final_path>.tmp`` and renames it into
    place on success, returning the final on-disk path.

    Raises:
        DownloadCancelled: the session was removed from the registry while
            the download was running.
        DownloadError: the server answered with a status other than 200.
        Exception: anything else (filesystem or network errors) is logged
            and re-raised unchanged.

    In every failure case the ``.tmp`` file is removed. The session is
    finished (removed from the registry) exactly once, here — callers do not
    need to call ``DOWNLOAD_SERVER.finish`` themselves. That holds even when
    resolving or creating the destination fails, because that setup runs
    inside the ``try`` whose ``finally`` finishes the session.
    """
    # Stays None until the destination is resolved, so the handlers below
    # know whether there is anything to roll back.
    tmp_path: Optional[str] = None
    bytes_seen = 0
    try:
        final_path, tmp_path = resolve_destination(session.model_id)
        os.makedirs(os.path.dirname(final_path), exist_ok=True)

        # Wipe any stale .tmp from a previous failed attempt before we start —
        # otherwise a partial body could masquerade as our completed download
        # when the rename finally happens.
        _remove_if_exists(tmp_path)

        http = await get_session()
        headers = _auth_headers_for(session.url)
        logging.info(
            "[model_downloader] starting GET %s (auth=%s)",
            session.url, "yes" if "Authorization" in headers else "no",
        )
        async with http.get(
            session.url,
            allow_redirects=True,
            timeout=REQUEST_TIMEOUT,
            headers=headers,
        ) as resp:
            if resp.status != 200:
                # Capture a snippet of the response body so 4xx/5xx aren't
                # opaque in the logs — HF returns JSON or HTML with a
                # human-readable reason on failures.
                body_snippet = await _read_short(resp)
                logging.warning(
                    "[model_downloader] GET %s failed: status=%d final_url=%s body=%s",
                    session.url, resp.status, str(resp.url), body_snippet,
                )
                raise DownloadError(
                    f"unexpected HTTP {resp.status} fetching {session.url}: {body_snippet}",
                    status=resp.status,
                )

            total = parse_content_length(resp.headers.get("Content-Length"))
            DOWNLOAD_SERVER.update_progress(session, 0, total)

            with open(tmp_path, "wb") as f:
                async for chunk in resp.content.iter_chunked(CHUNK_SIZE):
                    # Cancellation check between chunks. Cheap and means
                    # cancellation latency is bounded by one chunk plus
                    # one ``write()``.
                    if not DOWNLOAD_SERVER.is_active(session):
                        raise DownloadCancelled()
                    f.write(chunk)
                    bytes_seen += len(chunk)
                    DOWNLOAD_SERVER.update_progress(session, bytes_seen, total)

        # Final cancellation check before we promote the .tmp to the real
        # filename — avoids the awkward case where cancel arrives during
        # the very last chunk and we'd otherwise commit anyway.
        if not DOWNLOAD_SERVER.is_active(session):
            raise DownloadCancelled()

        # Atomic rename. os.replace is atomic within the same filesystem,
        # which is guaranteed here because tmp lives alongside final_path.
        os.replace(tmp_path, final_path)
        logging.info(
            "[model_downloader] downloaded %s (%d bytes) from %s",
            session.model_id, bytes_seen, session.url,
        )
        return final_path

    except DownloadCancelled:
        logging.info("[model_downloader] cancelled: %s", session.model_id)
        if tmp_path is not None:
            _remove_if_exists(tmp_path)
        raise
    except asyncio.CancelledError:
        # Task-level cancellation (e.g. shutdown) is a BaseException on
        # current Python versions and would bypass ``except Exception``;
        # roll back the partial file here too.
        logging.info("[model_downloader] task cancelled: %s", session.model_id)
        if tmp_path is not None:
            _remove_if_exists(tmp_path)
        raise
    except Exception as e:
        logging.warning(
            "[model_downloader] failed: %s from %s: %s: %s",
            session.model_id, session.url, type(e).__name__, e,
            exc_info=True,
        )
        if tmp_path is not None:
            _remove_if_exists(tmp_path)
        raise
    finally:
        # In all terminal states (success / cancel / error) drop the
        # session from the registry. Idempotent — only removes if we're
        # still the live epoch for this model_id.
        DOWNLOAD_SERVER.finish(session)
|
||||
|
||||
|
||||
class DownloadError(Exception):
    """Failure at the network / protocol level while downloading.

    ``status`` holds the HTTP status code when the failure was an unexpected
    response, and is ``None`` otherwise.
    """

    def __init__(self, message: str, status: Optional[int] = None) -> None:
        self.status = status
        super().__init__(message)
|
||||
|
||||
|
||||
async def _read_short(resp: aiohttp.ClientResponse, limit: int = 512) -> str:
    """Return a short, log-friendly prefix of a response body.

    Reads at most ``limit`` bytes from ``resp.content`` and decodes them as
    UTF-8, replacing undecodable bytes and stripping surrounding whitespace.
    Intended for logging the reason attached to a non-2xx response rather
    than only its status code.

    Best-effort: any exception raised while reading or decoding is swallowed
    and the placeholder ``"<unreadable>"`` is returned instead.
    """
    try:
        snippet = await resp.content.read(limit)
        text = snippet.decode("utf-8", errors="replace")
        return text.strip()
    except Exception:
        return "<unreadable>"
|
||||
|
||||
|
||||
def _auth_headers_for(url: str) -> dict[str, str]:
    """Build the extra auth headers to send with the GET for ``url``.

    Only HuggingFace URLs ever receive a header: the stored OAuth access
    token is attached as ``Authorization: Bearer <token>``. Every other host
    gets an empty dict so the token is never sent elsewhere. An empty dict is
    also returned when no token is stored or its access token is empty.
    """
    headers: dict[str, str] = {}
    if is_hf_url(url):
        stored = HF_AUTH_STORE.get_token_sync()
        if stored is not None and stored.access_token:
            headers["Authorization"] = f"Bearer {stored.access_token}"
    return headers
|
||||
|
||||
|
||||
def _remove_if_exists(path: str) -> None:
    """Delete ``path``, treating an already-missing file as success.

    Any other ``OSError`` is logged as a warning and otherwise ignored, so
    cleanup never raises.
    """
    try:
        os.unlink(path)
    except OSError as err:
        if isinstance(err, FileNotFoundError):
            # Nothing to clean up.
            return
        logging.warning("[model_downloader] could not remove %s: %s", path, err)
|
||||
|
||||
|
||||
async def run_batch_sequential(sessions: list[DownloadSession]) -> None:
    """Download every session in ``sessions``, strictly one at a time.

    Sessions are independent: a cancellation or failure of one never stops
    the ones after it. A session that is no longer active in the registry
    when its turn comes (cancelled while queued) is finished in the registry
    and skipped without touching disk.
    """
    for item in sessions:
        if DOWNLOAD_SERVER.is_active(item):
            try:
                await stream_to_disk(item)
            except (DownloadCancelled, Exception):
                # stream_to_disk has already logged the outcome and removed
                # its temp file; move on to the next session.
                pass
        else:
            # Cancelled before its turn: drop it from the registry and skip.
            DOWNLOAD_SERVER.finish(item)
|
||||
|
||||
|
||||
# Strong references to batch tasks that are still running. The event loop
# only keeps weak references to tasks, so a fire-and-forget task that nobody
# else holds may be garbage-collected before it finishes.
_BACKGROUND_BATCHES: set[asyncio.Task] = set()


def schedule_batch(sessions: list[DownloadSession]) -> asyncio.Task:
    """Kick off ``run_batch_sequential`` on the running event loop.

    The task is fire-and-forget from the caller's point of view: the API
    handler returns immediately after scheduling and clients observe
    progress via the polling endpoints. The module keeps its own reference
    to the task until it completes, so callers may safely discard the
    return value.

    Args:
        sessions: Sessions to download sequentially.

    Returns:
        The scheduled task.

    Raises:
        RuntimeError: if there is no running event loop.
    """
    task = asyncio.create_task(run_batch_sequential(sessions))
    _BACKGROUND_BATCHES.add(task)
    # Drop our reference once the batch is done so the set does not grow.
    task.add_done_callback(_BACKGROUND_BATCHES.discard)
    return task
|
||||
238
app/model_downloader/gated_detection.py
Normal file
238
app/model_downloader/gated_detection.py
Normal file
@ -0,0 +1,238 @@
|
||||
"""Per-URL probes for the unified availability endpoint.
|
||||
|
||||
Three cached/derived facts per URL:
|
||||
|
||||
- ``is_gated`` intrinsic to the model; cached forever once known.
|
||||
Determined by ``auth_check(repo_id, token=None)``:
|
||||
``GatedRepoError`` → True, success → False.
|
||||
|
||||
- ``is_hf_downloadable`` depends on the *current* token; recomputed every
|
||||
call. For non-gated URLs this is trivially True
|
||||
(no HF call needed). For gated URLs we run
|
||||
``auth_check`` with the stored token each call.
|
||||
|
||||
- ``file_size`` intrinsic to the file. Cached forever once
|
||||
determined (including ``None`` on transient
|
||||
failure — we don't retry). We only attempt the
|
||||
HEAD when we already know the URL is downloadable
|
||||
to us; that way a failed-because-gated probe
|
||||
never lands as a cached ``None``.
|
||||
|
||||
Caches are per-process, in-memory; small, no eviction needed for the
|
||||
workflow-scale (~tens of URLs). Concurrent calls for the same URL
|
||||
deduplicate via per-URL ``asyncio.Lock``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.errors import (
|
||||
GatedRepoError,
|
||||
HfHubHTTPError,
|
||||
RepositoryNotFoundError,
|
||||
)
|
||||
|
||||
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
|
||||
from app.model_downloader.hf_url import is_hf_url, repo_id_from_url
|
||||
from app.model_downloader.http_client import get_session, parse_content_length
|
||||
|
||||
|
||||
# Overall time budget (seconds) for a single size-probe HEAD request.
_HEAD_TIMEOUT = aiohttp.ClientTimeout(total=15)
|
||||
|
||||
|
||||
@dataclass
class ProbeResult:
    """Outcome of probing one URL: its size and HF downloadability."""

    # Size in bytes, or ``None`` when the size probe was skipped or produced
    # no answer.
    file_size: Optional[int]
    # ``True``/``False`` for HuggingFace URLs whose access could be decided;
    # ``None`` for non-HF URLs (not applicable) or when it could not be
    # determined.
    is_hf_downloadable: Optional[bool]
|
||||
|
||||
|
||||
# --- caches -------------------------------------------------------------- #
|
||||
|
||||
|
||||
# Gating verdict per URL (url -> bool): does the URL's HF repo gate access?
# Treated as intrinsic to the model, so an entry is never recomputed.
_is_gated_cache: dict[str, bool] = {}

# File size per URL (url -> Optional[int]); ``None`` means a probe ran and
# produced no answer. **Entries are only written when the URL was known to
# be downloadable to us at probe time**, so a gated-without-access URL never
# leaves behind a ``None`` that would outlive a later login.
_file_size_cache: dict[str, Optional[int]] = {}

# One lock per URL for single-flight probing: when several polls for the
# same URL arrive together, one performs the remote call and the others
# wait for its result.
_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
|
||||
def _lock_for(url: str) -> asyncio.Lock:
    """Return the lock dedicated to ``url``, creating it on first request."""
    try:
        return _locks[url]
    except KeyError:
        fresh = _locks[url] = asyncio.Lock()
        return fresh
|
||||
|
||||
|
||||
def clear_caches_for_tests() -> None:
    """Test-only helper: empty both probe caches and the per-URL lock table."""
    for table in (_is_gated_cache, _file_size_cache, _locks):
        table.clear()
|
||||
|
||||
|
||||
# --- public entrypoint --------------------------------------------------- #
|
||||
|
||||
|
||||
async def probe_url(url: str) -> ProbeResult:
    """Probe one URL for downloadability and size, using caches where safe.

    - Non-HF URLs: only the size is probed; ``is_hf_downloadable`` is
      ``None`` ("not applicable").
    - HF URLs whose repo id cannot be parsed, or whose gating status could
      not be determined: both fields are ``None``.
    - Non-gated HF URLs are downloadable by definition; gated ones are
      auth-checked with the currently stored token on every call.
    - The size is probed only when the URL is downloadable right now, so a
      denied probe can never poison the size cache.
    """
    if not is_hf_url(url):
        # Size is still cached so we do not HEAD on every poll.
        return ProbeResult(
            file_size=await _get_or_probe_size(url, token=None),
            is_hf_downloadable=None,
        )

    repo_id = repo_id_from_url(url)
    # Intrinsic gating is resolved once per URL.
    gated = None if repo_id is None else await _resolve_is_gated(url, repo_id)
    if gated is None:
        return ProbeResult(file_size=None, is_hf_downloadable=None)

    # Downloadability depends on the *current* token, so recompute per call.
    stored = HF_AUTH_STORE.get_token_sync()
    bearer: Optional[str] = stored.access_token if stored else None
    downloadable: Optional[bool] = (
        await _auth_check_with_token(repo_id, bearer) if gated else True
    )

    if downloadable is not True:
        # Skip the HEAD entirely; it would be denied and the resulting
        # ``None`` would stay cached even after a later login.
        return ProbeResult(file_size=None, is_hf_downloadable=downloadable)

    size = await _get_or_probe_size(url, token=bearer)
    return ProbeResult(file_size=size, is_hf_downloadable=downloadable)
|
||||
|
||||
|
||||
# --- gated/auth probes --------------------------------------------------- #
|
||||
|
||||
|
||||
async def _resolve_is_gated(url: str, repo_id: str) -> Optional[bool]:
    """Decide, once per URL, whether ``repo_id`` is a gated repo.

    Returns the cached verdict when there is one. Otherwise runs a
    token-less ``auth_check`` in a worker thread under the per-URL lock:

    - success -> ``False`` (cached)
    - ``GatedRepoError`` -> ``True`` (cached)
    - ``RepositoryNotFoundError`` -> ``True`` (cached); the repo is not
      visible without auth, but an authenticated check might still succeed
      if it is a private repo the user can see
    - any other error -> ``None`` and nothing is cached, so the next call
      retries
    """
    known = _is_gated_cache.get(url)
    if known is not None:
        return known

    async with _lock_for(url):
        # Re-check: another caller may have filled the cache while we waited.
        known = _is_gated_cache.get(url)
        if known is not None:
            return known

        try:
            await asyncio.to_thread(_auth_check_sync, repo_id, None)
        except (GatedRepoError, RepositoryNotFoundError):
            verdict = True
        except Exception as e:
            logging.debug(
                "[hf_auth] is_gated probe failed for %s (will retry): %s",
                repo_id, e,
            )
            return None  # don't cache; retry next call
        else:
            verdict = False

        _is_gated_cache[url] = verdict
        return verdict
|
||||
|
||||
|
||||
async def _auth_check_with_token(
    repo_id: str, token: Optional[str]
) -> Optional[bool]:
    """Run ``auth_check`` for ``repo_id`` with ``token`` in a worker thread.

    Returns:
        ``True`` when the check passes; ``False`` when access is denied
        (gated repo, repo not found, or an HTTP 401/403); ``None`` when the
        outcome is unknown (any other HTTP error or an unexpected exception).
    """
    try:
        await asyncio.to_thread(_auth_check_sync, repo_id, token)
    except (GatedRepoError, RepositoryNotFoundError):
        return False
    except HfHubHTTPError as e:
        # 401/403 covers org-SSO-required, revoked tokens, and similar:
        # all mean "can't fetch right now" from the user's point of view.
        response = getattr(e, "response", None)
        if getattr(response, "status_code", None) in (401, 403):
            return False
        logging.debug(
            "[hf_auth] auth_check transient failure for %s: %s", repo_id, e,
        )
        return None
    except Exception as e:
        logging.warning("[hf_auth] unexpected auth_check error for %s: %s", repo_id, e)
        return None
    return True
|
||||
|
||||
|
||||
def _auth_check_sync(repo_id: str, token: Optional[str]) -> None:
    """Thin sync wrapper around ``HfApi.auth_check`` for ``asyncio.to_thread``.

    Args:
        repo_id: ``<org>/<repo>`` identifier to check.
        token: Access token to authenticate with, or ``None`` (or an empty
            string) for an anonymous check.

    Raises:
        Whatever ``HfApi.auth_check`` raises (e.g. ``GatedRepoError``,
        ``RepositoryNotFoundError``, ``HfHubHTTPError``).
    """
    # huggingface_hub interprets ``token=None`` as "use the locally saved
    # token" (CLI login / environment), which would make the anonymous probe
    # silently authenticated. ``False`` is its documented way to disable
    # authentication, matching the downloader, which sends no header when we
    # hold no token ourselves.
    HfApi().auth_check(repo_id, token=token if token else False)
|
||||
|
||||
|
||||
# --- size probe ---------------------------------------------------------- #
|
||||
|
||||
|
||||
async def _get_or_probe_size(url: str, token: Optional[str]) -> Optional[int]:
    """Return the cached size for ``url``, probing it once on a cache miss.

    The probe result is cached even when it is ``None``. Concurrent callers
    for the same URL are serialised on the per-URL lock so only one of them
    issues the request.
    """
    try:
        return _file_size_cache[url]
    except KeyError:
        pass

    async with _lock_for(url):
        # Another waiter may have populated the entry while we queued.
        if url not in _file_size_cache:
            _file_size_cache[url] = await _probe_size_once(url, token=token)
        return _file_size_cache[url]
|
||||
|
||||
|
||||
async def _probe_size_once(url: str, token: Optional[str]) -> Optional[int]:
    """HEAD the URL and return the file size in bytes, or None on failure.

    HuggingFace serves LFS-tracked files via a 302 to a signed CDN URL.
    The real file size lives in the non-standard ``X-Linked-Size`` header
    on that 302 response (``Content-Length`` is the redirect-body length).
    Disabling redirect-follow lets us read either header on the same
    response:

    - LFS files:           302 + ``X-Linked-Size``
    - Small/non-LFS files: 200 + ``Content-Length``

    Args:
        url: URL to probe.
        token: Bearer token to send, or ``None``/empty for no auth header.

    Returns:
        The size in bytes, or ``None`` when neither header yields a size or
        the request fails (client error, timeout, or OS-level error).
    """
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    try:
        session = await get_session()
        async with session.head(
            url, allow_redirects=False, timeout=_HEAD_TIMEOUT, headers=headers,
        ) as resp:
            linked = parse_content_length(resp.headers.get("X-Linked-Size"))
            if linked is not None:
                return linked
            if resp.status == 200:
                return parse_content_length(resp.headers.get("Content-Length"))
            return None
    except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError, OSError):
        # ``asyncio.TimeoutError`` is listed explicitly: before Python 3.11
        # it is a separate class from the builtin ``TimeoutError``, so a
        # timed-out request would otherwise escape the "None on failure"
        # contract.
        return None
|
||||
|
||||
|
||||
# Backward-compatibility alias: lets consumers that still import the old
# name keep working during the refactor. Remove once the routes are updated.
MetadataProbeResult = ProbeResult
|
||||
106
app/model_downloader/hf_auth/auth_store.py
Normal file
106
app/model_downloader/hf_auth/auth_store.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""In-memory token cache with lazy disk persistence + refresh.
|
||||
|
||||
Public surface is the ``HF_AUTH_STORE`` singleton. Callers ask
|
||||
``get_valid_token()``; the store transparently loads the token from disk
|
||||
on first use, refreshes via the OAuth refresh_token if the cached
|
||||
access_token is expired, and returns ``None`` if neither path works.
|
||||
|
||||
The refresh path imports ``oauth.refresh_access_token`` lazily to
|
||||
avoid an import cycle (oauth needs the store to save tokens it
|
||||
acquires).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from app.model_downloader.hf_auth.token_store import (
|
||||
Token,
|
||||
delete_token,
|
||||
load_token,
|
||||
save_token,
|
||||
)
|
||||
|
||||
|
||||
class HfAuthStore:
    """In-memory holder for the HuggingFace token, backed by the disk store.

    The token is read from disk lazily on first access. Writes (``set_token``
    / ``clear``) update memory and disk together under a lock.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._token: Optional[Token] = None
        self._loaded_from_disk = False

    def _ensure_loaded(self) -> None:
        """Load the persisted token into memory the first time it is needed."""
        if self._loaded_from_disk:
            return
        with self._lock:
            # Re-check under the lock: another thread may have loaded it.
            if not self._loaded_from_disk:
                self._token = load_token()
                self._loaded_from_disk = True

    def has_token(self) -> bool:
        """Cheap check: is any token held in memory?

        No refresh is attempted; an expired-but-refreshable token still
        counts as "logged in" from the user's perspective.
        """
        self._ensure_loaded()
        return self._token is not None

    def set_token(self, token: Token) -> None:
        """Replace the in-memory token and persist it to disk."""
        with self._lock:
            self._token = token
            self._loaded_from_disk = True
            save_token(token)

    def clear(self) -> None:
        """Forget the token both in memory and on disk (logout)."""
        with self._lock:
            self._token = None
            self._loaded_from_disk = True
            delete_token()

    def get_token_sync(self) -> Optional[Token]:
        """Return the cached token as-is, without refreshing it.

        Meant for sync callers (e.g. building an Authorization header in a
        non-async path). The token may be expired; the caller decides what
        to do if the server then rejects it.
        """
        self._ensure_loaded()
        return self._token

    async def get_valid_token(self) -> Optional[Token]:
        """Return a usable token, refreshing it via OAuth when necessary.

        Returns ``None`` when there is no cached token, when the cached
        token is expired and has no refresh token, or when the refresh
        attempt fails. Callers should treat that as "not logged in".
        """
        self._ensure_loaded()
        current = self._token
        if current is None or current.is_valid():
            return current
        if not current.refresh_token:
            return None

        # Lazy import to avoid the oauth <-> store import cycle.
        from app.model_downloader.hf_auth.oauth import refresh_access_token

        try:
            renewed = await refresh_access_token(current.refresh_token)
        except Exception as e:
            logging.warning("[hf_auth] token refresh failed: %s", e)
            return None
        self.set_token(renewed)
        return renewed
|
||||
|
||||
|
||||
# Module-level singleton; this is the public surface of the module.
HF_AUTH_STORE = HfAuthStore()
|
||||
55
app/model_downloader/hf_auth/eligibility.py
Normal file
55
app/model_downloader/hf_auth/eligibility.py
Normal file
@ -0,0 +1,55 @@
|
||||
"""Whether this deployment is allowed to do interactive HF OAuth.
|
||||
|
||||
We only let the server hold a HuggingFace token under a strict trust
|
||||
assumption: this is a *single tenant local* install. Concretely:
|
||||
|
||||
- The server is bound to a loopback address. SSH tunneling /
|
||||
reverse-proxies can defeat this, but it's the strongest signal
|
||||
we have without an authentication system.
|
||||
- ``--multi-user`` is off. A shared token used implicitly by multiple
|
||||
declared users would be a footgun — one user's gated downloads
|
||||
would silently authenticate as another.
|
||||
|
||||
Anything else and the frontend hides the HF login UI entirely; gated
|
||||
models continue to show the "acquire it manually" message.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
|
||||
def _is_loopback(host: str | None) -> bool:
|
||||
"""Duplicates ``server.is_loopback`` (small, no shared module yet).
|
||||
|
||||
Resolves a host or IP literal to whether it lives on the loopback
|
||||
interface (127.0.0.0/8 for IPv4, ::1 for IPv6). Returns False for
|
||||
``0.0.0.0`` / ``::`` because those are bind-all wildcards, not
|
||||
loopback.
|
||||
"""
|
||||
if host is None:
|
||||
return False
|
||||
try:
|
||||
return ipaddress.ip_address(host).is_loopback
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
loopback = False
|
||||
for family in (socket.AF_INET, socket.AF_INET6):
|
||||
try:
|
||||
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
|
||||
for _family, _, _, _, sockaddr in r:
|
||||
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
|
||||
return loopback
|
||||
loopback = True
|
||||
except socket.gaierror:
|
||||
pass
|
||||
return loopback
|
||||
|
||||
|
||||
def is_hf_auth_eligible() -> bool:
    """Tell whether this deployment may surface the HF OAuth flow.

    Eligible only when the server listens on a loopback address and
    ``--multi-user`` is off.
    """
    bound_to_loopback = _is_loopback(args.listen)
    if not bound_to_loopback:
        return False
    return not args.multi_user
|
||||
277
app/model_downloader/hf_auth/oauth.py
Normal file
277
app/model_downloader/hf_auth/oauth.py
Normal file
@ -0,0 +1,277 @@
|
||||
"""OAuth 2.0 PKCE flow against HuggingFace's authorization server.
|
||||
|
||||
Wired so that ``POST /api/hf-auth-login-start`` can:
|
||||
1. Generate state + PKCE verifier/challenge in this process.
|
||||
2. Spin up a short-lived loopback HTTP server at port 41954 to
|
||||
receive the redirect callback from HF.
|
||||
3. Return the ``authorize_url`` for the frontend to open in a new tab.
|
||||
|
||||
After the user grants consent on huggingface.co, HF redirects to the
|
||||
local callback URL with ``code`` and ``state``. The callback server
|
||||
validates ``state`` (CSRF), exchanges the code for tokens via PKCE,
|
||||
hands the resulting Token to ``HF_AUTH_STORE.set_token``, and shuts
|
||||
itself down.
|
||||
|
||||
Before this can be exercised end-to-end a maintainer must register a
|
||||
HuggingFace OAuth app and substitute the ``HF_CLIENT_ID`` placeholder
|
||||
below. See the comment above the constant for the exact steps.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
||||
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
|
||||
from app.model_downloader.hf_auth.token_store import Token
|
||||
from app.model_downloader.http_client import ssl_context
|
||||
|
||||
|
||||
# --- HF OAuth app registration -------------------------------------------- #
|
||||
# NOTE: The OAuth client_id below is a placeholder. Before this feature can be
|
||||
# exercised end-to-end, a maintainer must register a HuggingFace OAuth app
|
||||
# under a Comfy-Org-controlled HF account and substitute its client_id here.
|
||||
# Detailed walkthrough is in docs/server-side-model-downloads-handover.html
|
||||
# ("HuggingFace OAuth app setup" section). Short version:
|
||||
# 1. huggingface.co → Settings → Connected Apps → "Create app"
|
||||
# 2. Default Scopes: check ``openid`` + ``profile`` (User Info) and
|
||||
# ``gated-repos`` (Repository Access). Leave everything else off.
|
||||
# 3. Redirect URLs: exactly ``http://127.0.0.1:41954/api/auth/huggingface/callback``
|
||||
# — must match ``REDIRECT_URI`` below; change both in lockstep if you
|
||||
# change ``CALLBACK_PORT``.
|
||||
# 4. Save → copy the resulting Client ID into ``HF_CLIENT_ID`` below.
|
||||
# The client_id is not a secret (it travels through the user's browser in
|
||||
# plaintext); HF's "Public app" type means there's no client secret to
|
||||
# manage — PKCE replaces it.
|
||||
HF_CLIENT_ID = "REPLACE_ME_WITH_COMFY_ORG_HF_OAUTH_CLIENT_ID"

# Loopback endpoint that receives HF's redirect. REDIRECT_URI is derived
# from the three parts below and must match the redirect URL registered for
# the OAuth app.
CALLBACK_HOST = "127.0.0.1"
CALLBACK_PORT = 41954
CALLBACK_PATH = "/api/auth/huggingface/callback"
REDIRECT_URI = "http://{}:{}{}".format(CALLBACK_HOST, CALLBACK_PORT, CALLBACK_PATH)

AUTHORIZE_URL = "https://huggingface.co/oauth/authorize"
TOKEN_URL = "https://huggingface.co/oauth/token"

# Smallest scope set that supports the feature:
#   - openid      : required by HF when the app uses OIDC at all
#   - profile     : lets ``HfApi.whoami(token=...)`` return a username for
#                   the settings UI; cosmetic but expected
#   - gated-repos : enough to call ``auth_check`` and download files from
#                   public gated repos whose license the user has accepted.
#                   The wider ``read-repos`` scope would also work (it
#                   includes ``gated-repos``) but additionally grants
#                   private-repo read access, which we don't need and which
#                   makes the consent screen scarier for the user.
SCOPE = "openid profile gated-repos"

# How long (seconds) the callback server waits for the user to finish
# consent on huggingface.co. After that the port closes and the user has to
# click "Log in" again.
CALLBACK_TIMEOUT_SECS = 300


# Process-wide lock so that two simultaneous /api/hf-auth-login-start
# requests never compete for CALLBACK_PORT.
_OAUTH_LOCK = threading.Lock()
|
||||
|
||||
|
||||
class OAuthInProgressError(Exception):
    """Raised when a login is requested while another OAuth attempt is running."""
|
||||
|
||||
|
||||
class OAuthCallbackError(Exception):
    """Raised when the OAuth callback reports an error (HF denied, port stolen, etc.)."""
|
||||
|
||||
|
||||
# --- PKCE primitives ------------------------------------------------------ #
|
||||
|
||||
|
||||
def _make_pkce() -> tuple[str, str, str]:
    """Generate fresh PKCE material and return ``(verifier, challenge, state)``.

    The verifier stays inside this process. The challenge (unpadded
    base64url of the verifier's SHA-256 digest) and the state travel through
    the user's browser; the state is compared on the callback so a malicious
    cross-origin redirect cannot inject a token.
    """
    verifier = secrets.token_urlsafe(64)
    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    # base64url without the trailing '=' padding.
    challenge = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")
    state = secrets.token_urlsafe(32)
    return verifier, challenge, state
|
||||
|
||||
|
||||
def _build_authorize_url(challenge: str, state: str) -> str:
    """Compose the HF authorize URL for one login attempt.

    The query string carries the client id, redirect URI, requested scope,
    the CSRF ``state`` and the PKCE ``challenge`` (method ``S256``).
    """
    from urllib.parse import urlencode

    query = urlencode(
        [
            ("client_id", HF_CLIENT_ID),
            ("redirect_uri", REDIRECT_URI),
            ("response_type", "code"),
            ("scope", SCOPE),
            ("state", state),
            ("code_challenge", challenge),
            ("code_challenge_method", "S256"),
        ]
    )
    return AUTHORIZE_URL + "?" + query
|
||||
|
||||
|
||||
# --- Token exchange ------------------------------------------------------- #
|
||||
|
||||
|
||||
async def _exchange_code(code: str, verifier: str) -> Token:
    """Exchange the authorization ``code`` for an access/refresh token pair.

    POSTs the PKCE ``verifier`` together with the code to ``TOKEN_URL`` and
    raises on a non-success HTTP status. ``expires_in`` defaults to 3600
    seconds and ``scope`` to ``SCOPE`` when absent from the response.
    """
    form = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": REDIRECT_URI,
        "client_id": HF_CLIENT_ID,
        "code_verifier": verifier,
    }
    budget = aiohttp.ClientTimeout(total=30)
    async with aiohttp.ClientSession(timeout=budget) as http:
        async with http.post(TOKEN_URL, data=form, ssl=ssl_context()) as resp:
            resp.raise_for_status()
            payload = await resp.json()

    return Token(
        access_token=payload["access_token"],
        refresh_token=payload.get("refresh_token"),
        # Convert the relative lifetime into an absolute epoch.
        expires_at=time.time() + float(payload.get("expires_in", 3600)),
        scope=payload.get("scope", SCOPE),
    )
|
||||
|
||||
|
||||
async def refresh_access_token(refresh_token: str) -> Token:
    """Trade a refresh_token for a new access (+ possibly refresh) token.

    Args:
        refresh_token: The refresh token currently held.

    Returns:
        A new ``Token``. When the response carries no usable refresh token
        (key missing, ``null`` or empty), the supplied ``refresh_token`` is
        kept so the session stays refreshable. A missing or ``null``
        ``expires_in`` falls back to 3600 seconds.

    Raises:
        aiohttp.ClientResponseError: on a non-success HTTP status.
        KeyError: if the response has no ``access_token``.
    """
    data = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": HF_CLIENT_ID,
    }
    timeout = aiohttp.ClientTimeout(total=30)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(TOKEN_URL, data=data, ssl=ssl_context()) as resp:
            resp.raise_for_status()
            body = await resp.json()

    access_token = body["access_token"]
    # If HF doesn't rotate refresh tokens, keep using the existing one. ``or``
    # (rather than a ``.get`` default) also covers an explicit null value,
    # which would otherwise silently drop the refresh token.
    next_refresh = body.get("refresh_token") or refresh_token
    expires_in = body.get("expires_in")
    if expires_in is None:
        expires_in = 3600
    return Token(
        access_token=access_token,
        refresh_token=next_refresh,
        expires_at=time.time() + float(expires_in),
        scope=body.get("scope", SCOPE),
    )
|
||||
|
||||
|
||||
# --- Callback server ------------------------------------------------------ #
|
||||
|
||||
|
||||
async def start_login_flow() -> str:
    """Begin one OAuth attempt: spawn the callback server, return the URL.

    Returns:
        The authorize URL the frontend should open in a new tab.

    Raises:
        OAuthInProgressError: another attempt currently holds the lock.

    The callback server runs as a background task until the user completes
    consent or until ``CALLBACK_TIMEOUT_SECS`` elapses; either way that task
    releases the lock and the port when it finishes.
    """
    if not _OAUTH_LOCK.acquire(blocking=False):
        raise OAuthInProgressError()

    try:
        verifier, challenge, state = _make_pkce()
        authorize_url = _build_authorize_url(challenge, state)
        # Fire the callback server on the running loop and return.
        asyncio.create_task(_run_callback_server(verifier, state))
    except BaseException:
        # Responsibility for releasing the lock only passes to the callback
        # task once it has been scheduled. If we never got that far, release
        # it here; otherwise every later login would be rejected as
        # "in progress".
        _OAUTH_LOCK.release()
        raise
    return authorize_url
|
||||
|
||||
|
||||
async def _run_callback_server(verifier: str, expected_state: str) -> None:
    """Listen for HF's redirect once, capture the token, then shut down.

    Serves ``CALLBACK_PATH`` on ``CALLBACK_HOST:CALLBACK_PORT`` and waits up
    to ``CALLBACK_TIMEOUT_SECS`` for a callback carrying the expected
    ``state``. On success the authorization code is exchanged for a token,
    which is handed to ``HF_AUTH_STORE.set_token``. Failures (bind error,
    timeout, error reported by HF, exchange failure) are logged and end the
    attempt.

    On every exit path the runner is cleaned up and ``_OAUTH_LOCK`` is
    released, so a later login attempt can always start.

    Args:
        verifier: PKCE verifier generated for this attempt.
        expected_state: ``state`` value the callback must echo back.
    """
    received: asyncio.Future[Token] = asyncio.get_event_loop().create_future()

    async def handler(request: web.Request) -> web.Response:
        try:
            # A wrong state never resolves the future: an unrelated request
            # must not be able to end the pending login.
            if request.query.get("state") != expected_state:
                return web.Response(status=400, text="state mismatch")
            err = request.query.get("error")
            if err:
                if not received.done():
                    received.set_exception(OAuthCallbackError(f"HF returned: {err}"))
                return web.Response(status=400, text=f"OAuth error: {err}")
            code = request.query.get("code")
            if not code:
                return web.Response(status=400, text="missing code")
            tok = await _exchange_code(code, verifier)
            if not received.done():
                received.set_result(tok)
            return web.Response(
                content_type="text/html",
                text=(
                    "<html><body style='font-family:sans-serif;padding:40px'>"
                    "<h2>HuggingFace login successful</h2>"
                    "<p>You can close this tab and return to ComfyUI.</p>"
                    "</body></html>"
                ),
            )
        except Exception as exc:
            if not received.done():
                received.set_exception(exc)
            return web.Response(status=500, text=str(exc))

    app = web.Application()
    app.router.add_get(CALLBACK_PATH, handler)
    runner = web.AppRunner(app)
    try:
        await runner.setup()
        site = web.TCPSite(runner, CALLBACK_HOST, CALLBACK_PORT, reuse_address=True)
        try:
            await site.start()
        except OSError as e:
            # Port already in use (or some other socket-bind failure). The
            # outer ``finally`` cleans up the runner and releases the lock
            # so a future attempt has a chance to succeed.
            logging.warning("[hf_auth] could not bind callback port: %s", e)
            return

        try:
            token = await asyncio.wait_for(received, timeout=CALLBACK_TIMEOUT_SECS)
        except asyncio.TimeoutError:
            logging.info("[hf_auth] OAuth login timed out after %ds", CALLBACK_TIMEOUT_SECS)
            return
        except OAuthCallbackError as e:
            logging.warning("[hf_auth] OAuth callback error: %s", e)
            return
        except Exception as e:
            logging.warning("[hf_auth] unexpected OAuth failure: %s", e)
            return

        HF_AUTH_STORE.set_token(token)
        logging.info("[hf_auth] OAuth login complete")
    finally:
        try:
            await runner.cleanup()
        finally:
            # Release even if cleanup itself fails; otherwise no further
            # login could ever start in this process.
            if _OAUTH_LOCK.locked():
                _OAUTH_LOCK.release()
|
||||
|
||||
|
||||
def is_login_in_progress() -> bool:
    """Report whether an OAuth attempt currently holds the process-wide lock."""
    in_progress = _OAUTH_LOCK.locked()
    return in_progress
|
||||
|
||||
|
||||
# Public API of this module: the login entry point, the refresh helper, the
# in-progress check, the in-progress error type and the callback timeout.
__all__ = [
    "start_login_flow",
    "refresh_access_token",
    "is_login_in_progress",
    "OAuthInProgressError",
    "CALLBACK_TIMEOUT_SECS",
]
|
||||
89
app/model_downloader/hf_auth/token_store.py
Normal file
89
app/model_downloader/hf_auth/token_store.py
Normal file
@ -0,0 +1,89 @@
|
||||
"""On-disk persistence for the HuggingFace OAuth token.
|
||||
|
||||
The token shape mirrors what HF returns on the token exchange: an
|
||||
``access_token``, an optional ``refresh_token``, the absolute epoch at
|
||||
which the access token expires, and the granted scope. We persist
|
||||
this so logging in once survives ComfyUI restarts; the file is mode
|
||||
``0600`` so only the OS user can read it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import stat
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Optional
|
||||
|
||||
import folder_paths
|
||||
|
||||
|
||||
# Safety margin (seconds): a token counts as expired this long before its
# server-reported ``expires_at``, so it cannot flip stale between an
# auth_check and the actual GET that follows it.
EXPIRY_BUFFER_SECS = 60

# File name of the persisted token inside the user directory.
TOKEN_FILENAME = "hf_auth_token.json"
|
||||
|
||||
|
||||
@dataclass
class Token:
    """An OAuth token together with the metadata needed to use it."""

    access_token: str
    refresh_token: Optional[str]
    expires_at: float  # absolute epoch seconds
    scope: str = ""

    def is_valid(self) -> bool:
        """Tell whether the token is usable right now.

        A token is usable when it has a non-empty access token and more
        than ``EXPIRY_BUFFER_SECS`` seconds remain before ``expires_at``.
        """
        if not self.access_token:
            return False
        remaining = self.expires_at - time.time()
        return remaining > EXPIRY_BUFFER_SECS
|
||||
|
||||
|
||||
def _token_path() -> str:
    """Return the path of the token file inside the user directory."""
    return os.path.join(folder_paths.get_user_directory(), TOKEN_FILENAME)
|
||||
|
||||
|
||||
def load_token() -> Optional[Token]:
    """Read the persisted token, returning ``None`` if absent or corrupt.

    Returns:
        The stored ``Token``; ``None`` when the file does not exist (silent)
        or when it cannot be read or parsed (logged as a warning).
    """
    path = _token_path()
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return Token(**data)
    except FileNotFoundError:
        # No token saved yet: the normal "not logged in" state.
        return None
    except (OSError, ValueError, TypeError) as e:
        # ValueError covers both malformed JSON (JSONDecodeError) and bytes
        # that are not valid UTF-8 (UnicodeDecodeError); TypeError covers
        # JSON that does not match the Token fields.
        logging.warning("[hf_auth] could not load token at %s: %s", path, e)
        return None
|
||||
|
||||
|
||||
def save_token(token: Token) -> None:
    """Atomically write the token with 0600 permissions.

    The JSON is written to a ``.tmp`` sibling that is created owner-only
    *before* any secret is written, then moved over the final path with
    ``os.replace``. The token is therefore never on disk with the default
    umask permissions, and a failed write leaves no temp file behind.

    Raises:
        OSError: if the directory or file cannot be created or written.
    """
    path = _token_path()
    os.makedirs(os.path.dirname(path), exist_ok=True)
    tmp = path + ".tmp"
    owner_rw = stat.S_IRUSR | stat.S_IWUSR

    # The mode argument only applies when the file is created here.
    fd = os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, owner_rw)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            try:
                # A stale tmp left by a crashed run keeps its old mode, so
                # tighten it explicitly before the secret is written.
                os.chmod(tmp, owner_rw)
            except OSError as e:
                logging.debug("[hf_auth] chmod 0600 on %s failed: %s", tmp, e)
            json.dump(asdict(token), f)
        os.replace(tmp, path)
    except BaseException:
        # Best-effort cleanup so a partial token file is not left behind;
        # the original error is what the caller needs to see.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise

    try:
        os.chmod(path, owner_rw)
    except OSError as e:
        # On Windows / weird filesystems chmod may be a no-op; not fatal.
        logging.debug("[hf_auth] chmod 0600 on %s failed: %s", path, e)
|
||||
|
||||
|
||||
def delete_token() -> None:
    """Delete the persisted token file; an already-missing file is fine."""
    target = _token_path()
    try:
        os.remove(target)
    except OSError as err:
        # Nothing to delete is the expected outcome of a repeated logout.
        if isinstance(err, FileNotFoundError):
            return
        logging.warning("[hf_auth] could not remove token at %s: %s", target, err)
|
||||
41
app/model_downloader/hf_url.py
Normal file
41
app/model_downloader/hf_url.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""Parsers for the ``huggingface.co`` URL shape we accept in workflows.
|
||||
|
||||
The download API accepts URLs of the form
|
||||
``https://huggingface.co/<org>/<repo>/resolve/<rev>/<path/to/file>``.
|
||||
We need to recover ``<org>/<repo>`` (the *repo_id*) from such URLs for
|
||||
``huggingface_hub`` API calls (notably ``HfApi.auth_check``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
_HF_HOST = "huggingface.co"
|
||||
|
||||
|
||||
def is_hf_url(url: str) -> bool:
    """Return True when the URL's hostname is exactly ``huggingface.co``.

    URLs that ``urlparse`` rejects with ``ValueError`` are reported as
    not being HuggingFace URLs.
    """
    try:
        host = urlparse(url).hostname
    except ValueError:
        return False
    return host == _HF_HOST
|
||||
|
||||
|
||||
def repo_id_from_url(url: str) -> Optional[str]:
    """Extract ``<org>/<repo>`` from an HF model file URL.

    The path must look like ``/<org>/<repo>/resolve/<rev>/...``: at least
    four segments with ``resolve`` as the third. Returns ``None`` for
    non-huggingface.co URLs, for any other path shape (datasets, spaces,
    ``/tree/``, ``/blob/``, ...), and when org or repo is empty.
    """
    if not is_hf_url(url):
        return None
    segments = urlparse(url).path.lstrip("/").split("/")
    has_resolve_shape = len(segments) >= 4 and segments[2] == "resolve"
    if not has_resolve_shape:
        return None
    owner, name = segments[0], segments[1]
    if owner and name:
        return f"{owner}/{name}"
    return None
|
||||
63
app/model_downloader/http_client.py
Normal file
63
app/model_downloader/http_client.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""Lazy module-level aiohttp ClientSession.
|
||||
|
||||
A single shared session means TLS handshakes are reused across HEAD probes
|
||||
and the subsequent GETs to the same host (HuggingFace is the dominant
|
||||
case), which is a noticeable speedup on cold connections.
|
||||
|
||||
We deliberately don't close the session at process exit — aiohttp's
|
||||
warning about unclosed sessions is benign at shutdown, and adding atexit
|
||||
plumbing buys nothing because the OS reclaims the sockets anyway. The
|
||||
session lifetime is the lifetime of the Python process.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ssl
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import certifi
|
||||
|
||||
|
||||
# Per-host connection cap. aiohttp's defaults are limit=100 total and
|
||||
# limit_per_host=0 (no per-host limit), so this bounds same-host concurrency.
|
||||
_CONNECTOR_LIMIT_PER_HOST = 8
|
||||
|
||||
_session: Optional[aiohttp.ClientSession] = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def ssl_context() -> ssl.SSLContext:
    """Build a default TLS context that verifies against certifi's CA bundle.

    Stated motivation: on some Python installs (python.org macOS, slim
    containers) the OS trust store isn't wired up, and TLS to
    huggingface.co fails with CERTIFICATE_VERIFY_FAILED when relying on it.
    """
    ca_bundle = certifi.where()
    return ssl.create_default_context(cafile=ca_bundle)
|
||||
|
||||
|
||||
async def get_session() -> aiohttp.ClientSession:
    """Return the shared session, building it on first use or after it closed."""
    global _session
    # Fast path: an open session already exists, no lock needed.
    current = _session
    if current is not None and not current.closed:
        return current
    async with _lock:
        # Re-check under the lock: another task may have created the
        # session while this one was waiting.
        if _session is not None and not _session.closed:
            return _session
        _session = aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(
                limit_per_host=_CONNECTOR_LIMIT_PER_HOST,
                ssl=ssl_context(),
            )
        )
        return _session
|
||||
|
||||
|
||||
def parse_content_length(value: Optional[str]) -> Optional[int]:
    """Parse a byte-count header value, or None if absent/malformed/negative.

    Args:
        value: Raw header text, e.g. ``"1048576"``. Surrounding whitespace
            is tolerated.

    Returns:
        The non-negative byte count, or ``None`` when the value is missing,
        empty, or anything other than a run of ASCII digits.
    """
    if not value:
        return None
    text = str(value).strip()
    # int() alone is too lenient for a header: it accepts "+5", "1_000" and
    # non-ASCII digits. Require plain ASCII digits, which also rules out
    # negative numbers.
    if not (text.isascii() and text.isdigit()):
        return None
    try:
        return int(text)
    except ValueError:
        # e.g. a digit string longer than the interpreter's int-conversion limit.
        return None
|
||||
93
app/model_downloader/paths.py
Normal file
93
app/model_downloader/paths.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Path resolution for model downloads.
|
||||
|
||||
Model identifiers used across the download API are *relative destination
|
||||
paths* of the form ``<directory>/<filename>`` (e.g. ``loras/my_lora.safetensors``).
|
||||
This module turns one of those identifiers into an absolute on-disk path
|
||||
under one of ComfyUI's registered model folders, while rejecting unknown
|
||||
folders, path traversal, and other ill-formed inputs.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import folder_paths
|
||||
|
||||
|
||||
# Constrain components so a model_id can never escape its target directory.
# - directory: a single path segment of safe chars
# - filename:  a single path segment of safe chars
# Anchored with \A / \Z (not ^ / $) because ``$`` also matches just before a
# trailing newline, which would let "name\n" through.
# NOTE(review): the "must end with a model extension" rule for filenames is
# not enforced by this pattern or by parse_model_id.
_SEGMENT_RE = re.compile(r"\A[A-Za-z0-9._-]+\Z")


class InvalidModelId(ValueError):
    """Raised when a model_id is syntactically invalid or refers to an
    unknown model folder."""


def _is_safe_segment(segment: str) -> bool:
    """Return True for a single safe path segment that is not ``.`` or ``..``."""
    if segment in (".", ".."):
        # Both satisfy the character class but are directory references.
        return False
    return _SEGMENT_RE.fullmatch(segment) is not None


def parse_model_id(model_id: str) -> Tuple[str, str]:
    """Split ``<directory>/<filename>`` and validate both components.

    Returns ``(directory, filename)``. Does NOT touch the filesystem.

    Raises:
        InvalidModelId: if ``model_id`` is not a string, does not contain
            exactly one ``/``, has an empty or unsafe component (anything
            outside ``[A-Za-z0-9._-]``, or ``.`` / ``..``), or names a
            directory missing from ``folder_paths.folder_names_and_paths``.
    """
    if not isinstance(model_id, str) or "/" not in model_id:
        raise InvalidModelId(f"model_id must be '<directory>/<filename>', got {model_id!r}")
    directory, _, filename = model_id.partition("/")
    if "/" in filename or not directory or not filename:
        raise InvalidModelId(f"model_id must be exactly one '/' separator, got {model_id!r}")
    if not _is_safe_segment(directory):
        raise InvalidModelId(f"invalid directory segment {directory!r}")
    if not _is_safe_segment(filename):
        raise InvalidModelId(f"invalid filename segment {filename!r}")
    if directory not in folder_paths.folder_names_and_paths:
        raise InvalidModelId(f"unknown model folder {directory!r}")
    return directory, filename
|
||||
|
||||
|
||||
def resolve_existing(model_id: str) -> Optional[str]:
    """Look up an installed model and return its absolute path, or None.

    The id is validated with ``parse_model_id`` (which raises
    ``InvalidModelId`` on bad input) and the lookup is delegated to
    ``folder_paths.get_full_path``, which per the original note also
    honours ``extra_model_paths.yaml``.
    """
    folder_name, file_name = parse_model_id(model_id)
    located = folder_paths.get_full_path(folder_name, file_name)
    return located
|
||||
|
||||
|
||||
def resolve_destination(model_id: str) -> Tuple[str, str]:
    """Return ``(final_path, tmp_path)`` for a download.

    Downloads land at the first registered path for the model's directory
    (the "primary" location). The ``.tmp`` sibling is used as the write
    target and atomically renamed on success.

    Raises:
        InvalidModelId: if ``model_id`` fails ``parse_model_id``, if the
            filename does not end with a model extension, or if the folder
            has no on-disk path registered.
    """
    # Keep in sync with the URL allowlist's model extensions.
    model_extensions = (".safetensors", ".sft", ".ckpt", ".pth", ".pt")

    directory, filename = parse_model_id(model_id)
    # This is the write path, so enforce the filename-extension rule here:
    # the destination name comes from untrusted workflow JSON, and the URL
    # allowlist only constrains the *source* URL's extension.
    # NOTE(review): the folder registry may contain non-model folders
    # (e.g. custom_nodes) where an arbitrary file name would be dangerous
    # — confirm against folder_paths.
    if not filename.lower().endswith(model_extensions):
        raise InvalidModelId(
            f"filename {filename!r} does not end with a supported model extension"
        )
    roots = folder_paths.get_folder_paths(directory)
    if not roots:
        raise InvalidModelId(f"no on-disk path registered for folder {directory!r}")
    root = roots[0]
    final_path = os.path.join(root, filename)
    tmp_path = final_path + ".tmp"
    return final_path, tmp_path
|
||||
|
||||
|
||||
def iter_all_tmp_paths():
    """Yield every ``*.tmp`` file under every registered model folder.

    Used at startup to sweep orphans left by crashed/restarted downloads.
    Only the top level of each root is scanned, each root is visited once,
    and roots that are missing or unreadable are skipped.
    """
    seen_roots: set[str] = set()
    # Snapshot the names: this generator is lazy, and iterating the live
    # dict would fail if a folder were registered between two yields.
    for directory in list(folder_paths.folder_names_and_paths):
        for root in folder_paths.get_folder_paths(directory):
            if root in seen_roots or not os.path.isdir(root):
                continue
            seen_roots.add(root)
            try:
                # The context manager closes the directory handle even if
                # the consumer stops iterating early.
                with os.scandir(root) as entries:
                    for entry in entries:
                        if entry.is_file() and entry.name.endswith(".tmp"):
                            yield entry.path
            except OSError:
                # Folder might be unreadable / missing on certain mounts —
                # not fatal, just skip it.
                continue
|
||||
@ -98,12 +98,24 @@ def _parse_cli_feature_flags() -> dict[str, Any]:
|
||||
|
||||
|
||||
# Default server capabilities
|
||||
def _hf_auth_eligible_at_startup() -> bool:
    """Evaluate HF-auth eligibility once, when the feature flags are built.

    The import is deferred to call time; per the original note, this
    module loads earlier in server boot than the model_downloader package.
    """
    from app.model_downloader.hf_auth.eligibility import is_hf_auth_eligible

    eligible = is_hf_auth_eligible()
    return eligible
|
||||
|
||||
|
||||
_CORE_FEATURE_FLAGS: dict[str, Any] = {
|
||||
"supports_preview_metadata": True,
|
||||
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
||||
"extension": {"manager": {"supports_v4": True}},
|
||||
"node_replacements": True,
|
||||
"assets": args.enable_assets,
|
||||
"server_side_model_downloads": True,
|
||||
"hf_auth_eligible": _hf_auth_eligible_at_startup(),
|
||||
}
|
||||
|
||||
# CLI-provided flags cannot overwrite core flags
|
||||
|
||||
1273
docs/server-side-model-downloads-handover.html
Normal file
1273
docs/server-side-model-downloads-handover.html
Normal file
File diff suppressed because it is too large
Load Diff
316
openapi.yaml
316
openapi.yaml
@ -188,6 +188,49 @@ components:
|
||||
- id
|
||||
- updated_at
|
||||
type: object
|
||||
AvailabilityStatusRequest:
|
||||
description: |
|
||||
Models to query — each entry is `model_id → URL`. The URL lets
|
||||
the server compute file_size + is_hf_downloadable on the same
|
||||
request, eliminating the need for a separate metadata endpoint.
|
||||
properties:
|
||||
models:
|
||||
additionalProperties:
|
||||
type: string
|
||||
description: model_id → URL declared in the workflow.
|
||||
type: object
|
||||
required:
|
||||
- models
|
||||
type: object
|
||||
AvailabilityStatusResponse:
|
||||
description: Per-model state + metadata + HF auth snapshot.
|
||||
properties:
|
||||
hf_auth:
|
||||
$ref: '#/components/schemas/HfAuthStatus'
|
||||
models:
|
||||
additionalProperties:
|
||||
$ref: '#/components/schemas/ModelStatusEntry'
|
||||
type: object
|
||||
required:
|
||||
- models
|
||||
- hf_auth
|
||||
type: object
|
||||
CancelDownloadSessionRequest:
|
||||
description: Request to cancel an in-flight download for a given model_id.
|
||||
properties:
|
||||
model_id:
|
||||
type: string
|
||||
required:
|
||||
- model_id
|
||||
type: object
|
||||
CancelDownloadSessionResponse:
|
||||
description: Result of a cancellation request.
|
||||
properties:
|
||||
cancelled:
|
||||
type: boolean
|
||||
required:
|
||||
- cancelled
|
||||
type: object
|
||||
CreateWorkflowRequest:
|
||||
description: Request body for creating a new saved workflow.
|
||||
properties:
|
||||
@ -230,6 +273,51 @@ components:
|
||||
- base_version
|
||||
- workflow_json
|
||||
type: object
|
||||
DownloadModelsRequest:
|
||||
description: Map of model_id → URL of files to fetch into the model folders.
|
||||
properties:
|
||||
models:
|
||||
additionalProperties:
|
||||
type: string
|
||||
description: model_id → URL of models to download.
|
||||
type: object
|
||||
required:
|
||||
- models
|
||||
type: object
|
||||
DownloadModelsResponse:
|
||||
description: Acknowledgement that downloads have been scheduled.
|
||||
properties:
|
||||
accepted:
|
||||
description: Always true; the request was scheduled.
|
||||
type: boolean
|
||||
scheduled:
|
||||
description: The list of model_ids whose downloads are now in-flight.
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
required:
|
||||
- accepted
|
||||
- scheduled
|
||||
type: object
|
||||
DownloadProgress:
|
||||
description: In-flight download progress; embedded in ModelStatusEntry.
|
||||
properties:
|
||||
bytes_downloaded:
|
||||
format: int64
|
||||
type: integer
|
||||
progress:
|
||||
description: Fraction in [0,1]; null until total_bytes is known.
|
||||
format: float
|
||||
nullable: true
|
||||
type: number
|
||||
total_bytes:
|
||||
description: Content-Length when supplied by the source.
|
||||
format: int64
|
||||
nullable: true
|
||||
type: integer
|
||||
required:
|
||||
- bytes_downloaded
|
||||
type: object
|
||||
ErrorResponse:
|
||||
description: Standard error response with a machine-readable code and human-readable message.
|
||||
properties:
|
||||
@ -394,6 +482,46 @@ components:
|
||||
- name
|
||||
- info
|
||||
type: object
|
||||
HfAuthLoginStartResponse:
|
||||
description: URL the frontend should open in a new tab to complete login.
|
||||
properties:
|
||||
authorize_url:
|
||||
type: string
|
||||
required:
|
||||
- authorize_url
|
||||
type: object
|
||||
HfAuthLogoutResponse:
|
||||
description: Result of the logout call (always logged_out = true).
|
||||
properties:
|
||||
logged_out:
|
||||
type: boolean
|
||||
required:
|
||||
- logged_out
|
||||
type: object
|
||||
HfAuthStatus:
|
||||
description: Inline snapshot of the server's HuggingFace OAuth state.
|
||||
properties:
|
||||
eligible:
|
||||
description: True iff this deployment can surface interactive HF login.
|
||||
type: boolean
|
||||
token_available:
|
||||
description: True iff a token (possibly expired but refreshable) is stored.
|
||||
type: boolean
|
||||
required:
|
||||
- token_available
|
||||
- eligible
|
||||
type: object
|
||||
HfAuthTokenStatusResponse:
|
||||
description: Whether the server holds an HF OAuth token + resolved username.
|
||||
properties:
|
||||
token_available:
|
||||
type: boolean
|
||||
username:
|
||||
nullable: true
|
||||
type: string
|
||||
required:
|
||||
- token_available
|
||||
type: object
|
||||
HistoryDetailEntry:
|
||||
description: History entry with full prompt data
|
||||
properties:
|
||||
@ -798,6 +926,40 @@ components:
|
||||
- name
|
||||
- folders
|
||||
type: object
|
||||
ModelStatusEntry:
|
||||
description: Everything the UI needs to render one row of the model.
|
||||
properties:
|
||||
file_size:
|
||||
description: Bytes, when known. Cached server-side per URL.
|
||||
format: int64
|
||||
nullable: true
|
||||
type: integer
|
||||
is_hf_downloadable:
|
||||
description: |
|
||||
HuggingFace-only signal. True if the server can fetch this
|
||||
URL with its current auth state (public, or gated-with-access).
|
||||
False if gated and lacking access. Null for non-HF URLs and
|
||||
for HF URLs whose probe failed entirely.
|
||||
nullable: true
|
||||
type: boolean
|
||||
progress:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/DownloadProgress'
|
||||
description: Present when `state == downloading`.
|
||||
nullable: true
|
||||
state:
|
||||
description: |
|
||||
`available` — file is on disk.
|
||||
`missing` — not on disk and no download in flight.
|
||||
`downloading` — server is currently fetching the file.
|
||||
enum:
|
||||
- available
|
||||
- missing
|
||||
- downloading
|
||||
type: string
|
||||
required:
|
||||
- state
|
||||
type: object
|
||||
NodeInfo:
|
||||
description: Metadata describing a single ComfyUI node type and its inputs/outputs.
|
||||
properties:
|
||||
@ -2338,6 +2500,60 @@ paths:
|
||||
summary: Get tag histogram for filtered assets
|
||||
tags:
|
||||
- file
|
||||
/api/cancel-model-download-session:
|
||||
post:
|
||||
description: |
|
||||
Cancels the download session for the given model_id. The worker
|
||||
observes the cancellation between chunks, removes its partial `.tmp`
|
||||
file, and exits without writing the destination path.
|
||||
operationId: postCancelModelDownloadSession
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/CancelDownloadSessionRequest'
|
||||
required: true
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/CancelDownloadSessionResponse'
|
||||
description: Session cancelled.
|
||||
"404":
|
||||
description: No active download for that model_id.
|
||||
summary: Cancel an in-flight server-side model download
|
||||
tags:
|
||||
- model
|
||||
/api/download-models:
|
||||
post:
|
||||
description: |
|
||||
Schedules downloads for every model_id in the request map. Returns
|
||||
immediately after validation; progress is observed via
|
||||
`/api/models-availability-status`. Fails atomically if any model
|
||||
is already on disk, already downloading, gated, or has a URL that
|
||||
is not on the server's allowlist (HuggingFace, Civitai, localhost).
|
||||
operationId: postDownloadModels
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/DownloadModelsRequest'
|
||||
required: true
|
||||
responses:
|
||||
"202":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/DownloadModelsResponse'
|
||||
description: Downloads accepted and scheduled.
|
||||
"400":
|
||||
description: One of the requested models is invalid, gated, or has a non-allowed URL.
|
||||
"409":
|
||||
description: One of the requested models is already on disk or downloading.
|
||||
summary: Start a server-side download of one or more models
|
||||
tags:
|
||||
- model
|
||||
/api/embeddings:
|
||||
get:
|
||||
description: Returns the list of text-encoder embeddings available on disk.
|
||||
@ -2639,6 +2855,66 @@ paths:
|
||||
summary: Get a specific subgraph blueprint
|
||||
tags:
|
||||
- workflow
|
||||
/api/hf-auth-login-start:
|
||||
post:
|
||||
description: |
|
||||
Spawns a short-lived loopback callback server (port 41954) and
|
||||
returns the URL the frontend should open in a new tab. After the
|
||||
user grants consent, HF redirects back to the callback URL with
|
||||
an authorization code; the server exchanges that for a token and
|
||||
persists it. Subsequent `/api/hf-auth-token-status` calls will
|
||||
return `token_available: true`. Rejected with 403 if the
|
||||
deployment is not eligible (not loopback or in --multi-user mode);
|
||||
409 if another login attempt is already in progress.
|
||||
operationId: postHfAuthLoginStart
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HfAuthLoginStartResponse'
|
||||
description: Login flow started; `authorize_url` is ready.
|
||||
"403":
|
||||
description: Deployment is not eligible for interactive HF login.
|
||||
"409":
|
||||
description: Another login attempt is already in progress.
|
||||
summary: Begin a HuggingFace OAuth login flow
|
||||
tags:
|
||||
- model
|
||||
/api/hf-auth-logout:
|
||||
post:
|
||||
description: Clears the in-memory cache and removes the on-disk token file.
|
||||
operationId: postHfAuthLogout
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HfAuthLogoutResponse'
|
||||
description: Logged out (idempotent — succeeds even if no token was held).
|
||||
summary: Drop the stored HuggingFace OAuth token
|
||||
tags:
|
||||
- model
|
||||
/api/hf-auth-token-status:
|
||||
get:
|
||||
description: |
|
||||
Returns `token_available: true` when the server has a token
|
||||
in memory (or on disk) for HuggingFace, irrespective of whether
|
||||
the access_token is currently fresh — an expired one with a
|
||||
refresh_token still counts as "logged in" because we'll refresh
|
||||
transparently on next use. If a username is resolvable via
|
||||
`HfApi.whoami` we return that too, for the settings UI.
|
||||
operationId: getHfAuthTokenStatus
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HfAuthTokenStatusResponse'
|
||||
description: Token status.
|
||||
summary: Whether the server holds a usable HuggingFace OAuth token
|
||||
tags:
|
||||
- model
|
||||
/api/history:
|
||||
post:
|
||||
deprecated: true
|
||||
@ -3141,6 +3417,44 @@ paths:
|
||||
summary: Cancel multiple jobs
|
||||
tags:
|
||||
- workflow
|
||||
/api/models-availability-status:
|
||||
post:
|
||||
description: |
|
||||
Given a map of `{model_id: url}` (model_id is
|
||||
`<directory>/<filename>`), returns per-id state plus the
|
||||
metadata the UI needs to render the row:
|
||||
|
||||
- `state` — one of `available` / `missing` / `downloading`
|
||||
- `progress` — embedded when `state == downloading`
|
||||
- `file_size` — bytes (when known)
|
||||
- `is_hf_downloadable` — for HF URLs only: true if the
|
||||
server can currently fetch the file with its stored auth
|
||||
state, false if gated and lacking access, null otherwise
|
||||
|
||||
Designed for 1 Hz polling. `file_size` and the intrinsic
|
||||
"is this model gated" check are cached server-side per URL;
|
||||
`is_hf_downloadable` is recomputed per call so license
|
||||
acceptance and login/logout transitions show up within one
|
||||
poll interval without any client-side cache plumbing.
|
||||
operationId: postModelsAvailabilityStatus
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/AvailabilityStatusRequest'
|
||||
required: true
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/AvailabilityStatusResponse'
|
||||
description: Per-model status and metadata.
|
||||
"400":
|
||||
description: Malformed request body.
|
||||
summary: Unified per-model status + metadata for the polling UI
|
||||
tags:
|
||||
- model
|
||||
/api/node_replacements:
|
||||
get:
|
||||
description: |
|
||||
@ -5087,3 +5401,5 @@ tags:
|
||||
name: queue
|
||||
- description: Job lifecycle queries
|
||||
name: job
|
||||
- description: Server-side model availability and downloads
|
||||
name: model
|
||||
|
||||
@ -9,6 +9,7 @@ numpy>=1.25.0
|
||||
einops
|
||||
transformers>=4.50.3
|
||||
tokenizers>=0.13.3
|
||||
huggingface_hub
|
||||
sentencepiece
|
||||
safetensors>=0.4.2
|
||||
aiohttp>=3.11.8
|
||||
|
||||
@ -47,6 +47,7 @@ from app.assets.seeder import asset_seeder
|
||||
from app.assets.api.routes import register_assets_routes
|
||||
from app.assets.services.ingest import register_file_in_place
|
||||
from app.assets.services.asset_management import resolve_hash_to_path
|
||||
from app.model_downloader.api.routes import register_routes as register_model_downloader_routes
|
||||
|
||||
from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
@ -256,6 +257,7 @@ class PromptServer():
|
||||
else:
|
||||
register_assets_routes(self.app)
|
||||
asset_seeder.disable()
|
||||
register_model_downloader_routes(self.app)
|
||||
routes = web.RouteTableDef()
|
||||
self.routes = routes
|
||||
self.last_node_id = None
|
||||
|
||||
567
tests-unit/app_test/hf_auth_test.py
Normal file
567
tests-unit/app_test/hf_auth_test.py
Normal file
@ -0,0 +1,567 @@
|
||||
"""Unit tests for the HuggingFace auth subsystem.
|
||||
|
||||
Covers:
|
||||
- token store: save/load roundtrip, chmod 0600, atomic write, delete
|
||||
- eligibility under various CLI-arg combinations
|
||||
- URL parsing (huggingface.co host detection + repo_id extraction)
|
||||
- HF-aware gated_detection.probe_url (mocked auth_check)
|
||||
- HF auth routes (token status, login start with eligibility gate, logout)
|
||||
- PKCE primitives + authorize URL shape
|
||||
|
||||
The OAuth callback server itself isn't exercised end-to-end here — that
|
||||
requires a real HF server. We test the components (state checking,
|
||||
URL building, code-exchange request shape) instead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import stat
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from app.model_downloader.api.routes import register_routes
|
||||
from app.model_downloader.hf_auth import oauth
|
||||
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE, HfAuthStore
|
||||
from app.model_downloader.hf_auth.token_store import (
|
||||
EXPIRY_BUFFER_SECS,
|
||||
Token,
|
||||
delete_token,
|
||||
load_token,
|
||||
save_token,
|
||||
)
|
||||
from app.model_downloader.hf_url import is_hf_url, repo_id_from_url
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Fixtures
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
@pytest.fixture
def patched_user_dir(tmp_path):
    """Point ``folder_paths.get_user_directory`` at a throwaway directory.

    Yields that directory so the token file never lands in the real user dir.
    """
    sandbox = tmp_path / "user"
    sandbox.mkdir()
    redirect = patch("folder_paths.get_user_directory", return_value=str(sandbox))
    with redirect:
        yield sandbox
|
||||
|
||||
|
||||
@pytest.fixture
def fresh_auth_store():
    """Reset the auth singleton and probe caches before and after a test."""
    from app.model_downloader import gated_detection

    def _reset():
        HF_AUTH_STORE._token = None
        HF_AUTH_STORE._loaded_from_disk = False
        gated_detection.clear_caches_for_tests()

    _reset()
    yield HF_AUTH_STORE
    _reset()
|
||||
|
||||
|
||||
@pytest.fixture
def app(patched_user_dir, fresh_auth_store):
    """aiohttp application with the model-downloader routes registered."""
    application = web.Application()
    register_routes(application)
    return application
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# URL parsing
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def test_is_hf_url_recognises_huggingface_co():
    """Only URLs hosted on huggingface.co itself are recognised."""
    hf_urls = (
        "https://huggingface.co/x/y/resolve/main/z.safetensors",
        "https://huggingface.co/abc",
    )
    other_urls = (
        "https://hf-mirror.com/x/y/resolve/main/z.safetensors",
        "https://civitai.com/x.safetensors",
    )
    for url in hf_urls:
        assert is_hf_url(url)
    for url in other_urls:
        assert not is_hf_url(url)
|
||||
|
||||
|
||||
def test_repo_id_from_url_extracts_org_and_repo():
    """A plain resolve URL yields ``<org>/<repo>``."""
    repo_id = repo_id_from_url(
        "https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-HDR/resolve/main/x.safetensors"
    )
    assert repo_id == "Lightricks/LTX-2.3-22b-IC-LoRA-HDR"
|
||||
|
||||
|
||||
def test_repo_id_from_url_handles_nested_path():
    """Sub-directories after the revision do not affect the repo id."""
    repo_id = repo_id_from_url(
        "https://huggingface.co/Comfy-Org/ltx-2.3/resolve/main/split_files/loras/x.safetensors"
    )
    assert repo_id == "Comfy-Org/ltx-2.3"
|
||||
|
||||
|
||||
def test_repo_id_from_url_returns_none_for_non_hf():
    """URLs on other hosts have no HF repo id."""
    repo_id = repo_id_from_url("https://civitai.com/x.safetensors")
    assert repo_id is None
|
||||
|
||||
|
||||
def test_repo_id_from_url_returns_none_for_non_resolve_paths():
    """HF URLs that are not ``/resolve/`` file URLs are out of scope."""
    out_of_scope = (
        "https://huggingface.co/org/repo/blob/main/x.safetensors",
        "https://huggingface.co/org",
    )
    for url in out_of_scope:
        assert repo_id_from_url(url) is None
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Token store
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def test_token_store_roundtrip(patched_user_dir):
    """A saved token loads back equal to the original."""
    original = Token(
        access_token="hf_abc",
        refresh_token="rf_def",
        expires_at=9999999999.0,
        scope="openid profile",
    )
    save_token(original)
    assert load_token() == original
|
||||
|
||||
|
||||
def test_token_store_writes_0600(patched_user_dir):
    """On POSIX the token file ends up with mode 0600."""
    save_token(Token(access_token="x", refresh_token=None, expires_at=0.0))
    token_file = os.path.join(patched_user_dir, "hf_auth_token.json")
    perms = stat.S_IMODE(os.stat(token_file).st_mode)
    # On Windows chmod is silently a no-op, so the mode is only
    # checked on POSIX.
    if os.name != "posix":
        return
    assert perms == 0o600
|
||||
|
||||
|
||||
def test_token_store_delete_removes_file(patched_user_dir):
    """Deleting removes the file, and deleting again does not raise."""
    token_file = os.path.join(patched_user_dir, "hf_auth_token.json")
    save_token(Token(access_token="x", refresh_token=None, expires_at=0.0))
    delete_token()
    assert not os.path.exists(token_file)
    # Second call exercises the already-missing case.
    delete_token()
|
||||
|
||||
|
||||
def test_token_store_load_returns_none_for_missing_file(patched_user_dir):
    """With no token file present, loading yields ``None``."""
    loaded = load_token()
    assert loaded is None
|
||||
|
||||
|
||||
def test_token_store_load_returns_none_for_corrupt_file(patched_user_dir):
    """A token file that is not valid JSON loads as ``None``."""
    token_file = os.path.join(patched_user_dir, "hf_auth_token.json")
    with open(token_file, "w") as handle:
        handle.write("not json {")
    loaded = load_token()
    assert loaded is None
|
||||
|
||||
|
||||
def test_token_is_valid_uses_buffer(patched_user_dir):
    """A token expiring inside the buffer window counts as invalid."""
    import time

    long_lived = Token(
        access_token="x",
        refresh_token=None,
        expires_at=time.time() + 3600,
    )
    inside_buffer = Token(
        access_token="x",
        refresh_token=None,
        expires_at=time.time() + EXPIRY_BUFFER_SECS - 1,
    )
    assert long_lived.is_valid()
    assert not inside_buffer.is_valid()
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Auth store
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def test_auth_store_loads_lazily(patched_user_dir):
    """A new store picks up a token that was already on disk."""
    on_disk = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
    save_token(on_disk)
    auth = HfAuthStore()
    assert auth.has_token()
    assert auth.get_token_sync() == on_disk
|
||||
|
||||
|
||||
def test_auth_store_set_persists(patched_user_dir):
    """``set_token`` persists the token so a second store can read it."""
    first = HfAuthStore()
    stored = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
    first.set_token(stored)
    # A separately constructed store has no in-memory state of its own.
    second = HfAuthStore()
    assert second.get_token_sync() == stored
|
||||
|
||||
|
||||
def test_auth_store_clear_removes_in_memory_and_on_disk(patched_user_dir):
    """``clear`` drops the token from this store and from any fresh one."""
    auth = HfAuthStore()
    auth.set_token(
        Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
    )
    auth.clear()
    assert not auth.has_token()
    reloaded = HfAuthStore().get_token_sync()
    assert reloaded is None
|
||||
|
||||
|
||||
async def test_auth_store_get_valid_returns_fresh_token(patched_user_dir):
    """A token that is still fresh is returned unchanged."""
    import time

    auth = HfAuthStore()
    unexpired = Token(
        access_token="x",
        refresh_token=None,
        expires_at=time.time() + 3600,
    )
    auth.set_token(unexpired)
    assert await auth.get_valid_token() == unexpired
|
||||
|
||||
|
||||
async def test_auth_store_get_valid_refresh_on_expired(patched_user_dir):
    """An expired token is replaced by what the refresh call returns."""
    import time

    auth = HfAuthStore()
    stale = Token(
        access_token="old",
        refresh_token="rf",
        expires_at=time.time() - 100,
    )
    auth.set_token(stale)
    replacement = Token(
        access_token="new",
        refresh_token="rf",
        expires_at=time.time() + 3600,
    )
    fake_refresh = AsyncMock(return_value=replacement)
    with patch(
        "app.model_downloader.hf_auth.oauth.refresh_access_token",
        new=fake_refresh,
    ):
        obtained = await auth.get_valid_token()
    assert obtained == replacement
|
||||
|
||||
|
||||
async def test_auth_store_get_valid_returns_none_on_refresh_failure(patched_user_dir):
    """When the refresh call raises, ``get_valid_token`` yields ``None``."""
    import time

    auth_store = HfAuthStore()
    stale = Token(
        access_token="old", refresh_token="rf", expires_at=time.time() - 100
    )
    auth_store.set_token(stale)

    # The refresh coroutine blows up instead of returning a token.
    failing_refresh = patch(
        "app.model_downloader.hf_auth.oauth.refresh_access_token",
        new=AsyncMock(side_effect=RuntimeError("HF down")),
    )
    with failing_refresh:
        obtained = await auth_store.get_valid_token()
    assert obtained is None
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Eligibility
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "listen,multi_user,expected",
    [
        ("127.0.0.1", False, True),
        ("127.0.0.1", True, False),  # multi-user disables it
        ("0.0.0.0", False, False),  # bind-all is not loopback
        ("0.0.0.0", True, False),
        ("192.168.1.5", False, False),  # LAN address
        ("::1", False, True),  # IPv6 loopback
    ],
)
def test_eligibility(listen, multi_user, expected, monkeypatch):
    """``is_hf_auth_eligible`` tracks the ``listen`` / ``multi_user`` CLI args."""
    from app.model_downloader.hf_auth import eligibility
    from comfy.cli_args import args

    # Override both CLI options for this one parametrized case.
    overrides = {"listen": listen, "multi_user": multi_user}
    for option, value in overrides.items():
        monkeypatch.setattr(args, option, value)

    verdict = eligibility.is_hf_auth_eligible()
    assert verdict is expected
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# gated_detection HF probe
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
async def test_probe_url_hf_public(fresh_auth_store):
    """A non-raising auth check plus a 1024-byte size probe → downloadable."""
    from app.model_downloader.gated_detection import probe_url

    target = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
    auth_patch = patch("app.model_downloader.gated_detection._auth_check_sync")
    size_patch = patch(
        "app.model_downloader.gated_detection._probe_size_once",
        new=AsyncMock(return_value=1024),
    )
    with auth_patch, size_patch:
        outcome = await probe_url(target)

    assert outcome.is_hf_downloadable is True
    assert outcome.file_size == 1024
|
||||
|
||||
|
||||
async def test_probe_url_hf_gated_no_access(fresh_auth_store):
    """An auth check raising ``GatedRepoError`` → ``is_hf_downloadable`` False."""
    from huggingface_hub.errors import GatedRepoError

    from app.model_downloader.gated_detection import probe_url

    target = "https://huggingface.co/gated/repo/resolve/main/x.safetensors"
    forbidden = MagicMock(status_code=403)
    auth_patch = patch(
        "app.model_downloader.gated_detection._auth_check_sync",
        side_effect=GatedRepoError("gated", response=forbidden),
    )
    size_patch = patch(
        "app.model_downloader.gated_detection._probe_size_once",
        new=AsyncMock(return_value=None),
    )
    with auth_patch, size_patch:
        outcome = await probe_url(target)

    assert outcome.is_hf_downloadable is False
|
||||
|
||||
|
||||
async def test_probe_url_non_hf_skips_auth_check():
    """A Civitai URL is size-probed without the auth check ever being called."""
    from app.model_downloader.gated_detection import probe_url

    target = "https://civitai.com/api/download/models/1.safetensors"
    auth_patch = patch("app.model_downloader.gated_detection._auth_check_sync")
    size_patch = patch(
        "app.model_downloader.gated_detection._probe_size_once",
        new=AsyncMock(return_value=2048),
    )
    with auth_patch as auth_check, size_patch:
        outcome = await probe_url(target)

    assert outcome.is_hf_downloadable is None
    assert outcome.file_size == 2048
    auth_check.assert_not_called()
|
||||
|
||||
|
||||
async def test_is_gated_cached_across_calls(fresh_auth_store):
    """Probing the same URL three times triggers exactly one auth check.

    The repeated ``probe_url`` calls must be served without re-issuing
    the auth check for a URL that has already been classified.
    """
    from app.model_downloader.gated_detection import probe_url

    target = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
    auth_patch = patch("app.model_downloader.gated_detection._auth_check_sync")
    size_patch = patch(
        "app.model_downloader.gated_detection._probe_size_once",
        new=AsyncMock(return_value=1024),
    )
    with auth_patch as auth_check, size_patch:
        for _attempt in range(3):
            await probe_url(target)

    # Three probes of a public URL still add up to a single auth check.
    assert auth_check.call_count == 1
|
||||
|
||||
|
||||
async def test_file_size_cached_across_calls(fresh_auth_store):
    """A second probe of the same URL reuses the size without probing again."""
    from app.model_downloader.gated_detection import probe_url

    target = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
    auth_patch = patch("app.model_downloader.gated_detection._auth_check_sync")
    size_patch = patch(
        "app.model_downloader.gated_detection._probe_size_once",
        new=AsyncMock(return_value=2048),
    )
    with auth_patch, size_patch as size_probe:
        first = await probe_url(target)
        second = await probe_url(target)

    assert (first.file_size, second.file_size) == (2048, 2048)
    assert size_probe.call_count == 1
|
||||
|
||||
|
||||
async def test_file_size_not_probed_for_gated_no_access(fresh_auth_store):
    """No size probe is issued when the auth check reports a gated repo.

    Probing anyway would let a 401-due-to-gating be stored as a cached
    ``None`` size that outlives a later successful login.
    """
    from app.model_downloader.gated_detection import probe_url
    from huggingface_hub.errors import GatedRepoError
    from unittest.mock import MagicMock

    target = "https://huggingface.co/gated/repo/resolve/main/x.safetensors"
    forbidden = MagicMock(status_code=403)
    auth_patch = patch(
        "app.model_downloader.gated_detection._auth_check_sync",
        side_effect=GatedRepoError("gated", response=forbidden),
    )
    size_patch = patch(
        "app.model_downloader.gated_detection._probe_size_once",
        new=AsyncMock(return_value=None),
    )
    with auth_patch, size_patch as size_probe:
        outcome = await probe_url(target)

    assert outcome.is_hf_downloadable is False
    assert outcome.file_size is None
    assert size_probe.call_count == 0
|
||||
|
||||
|
||||
async def test_probe_url_passes_token_when_available(fresh_auth_store, patched_user_dir):
    """A gated URL is auth-checked twice: without and then with the token.

    The first call (``token=None``) establishes that the repo is gated;
    the second call carries the stored access token and decides
    ``is_hf_downloadable`` for the current user.
    """
    from app.model_downloader import gated_detection
    from app.model_downloader.gated_detection import probe_url
    from huggingface_hub.errors import GatedRepoError
    from unittest.mock import MagicMock

    gated_detection.clear_caches_for_tests()
    user_token = Token(
        access_token="hf_test_token",
        refresh_token=None,
        expires_at=9999999999.0,
    )
    fresh_auth_store.set_token(user_token)
    target = "https://huggingface.co/private/repo/resolve/main/x.safetensors"

    forbidden = MagicMock(status_code=403)

    def fake_auth_check(repo_id, token):
        # Anonymous call → gated. Any call carrying a token returns
        # normally, i.e. the user has access.
        if token is None:
            raise GatedRepoError("gated", response=forbidden)

    auth_patch = patch(
        "app.model_downloader.gated_detection._auth_check_sync",
        side_effect=fake_auth_check,
    )
    size_patch = patch(
        "app.model_downloader.gated_detection._probe_size_once",
        new=AsyncMock(return_value=None),
    )
    with auth_patch as auth_check, size_patch:
        outcome = await probe_url(target)

    # The token-authenticated call succeeded, so the file is downloadable.
    assert outcome.is_hf_downloadable is True
    # Exactly two calls: (repo_id, None) followed by (repo_id, <token>).
    assert auth_check.call_count == 2
    anonymous_call, authed_call = auth_check.call_args_list
    assert anonymous_call.args == ("private/repo", None)
    assert authed_call.args == ("private/repo", "hf_test_token")
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# OAuth primitives
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def test_make_pkce_returns_distinct_high_entropy_values():
    """Two ``_make_pkce`` calls differ in verifier, challenge and state."""
    first = oauth._make_pkce()
    second = oauth._make_pkce()

    verifier_a, challenge_a, state_a = first
    verifier_b, challenge_b, state_b = second
    assert verifier_a != verifier_b
    assert challenge_a != challenge_b
    assert state_a != state_b

    # The PKCE spec sets 43 characters as the minimum verifier length.
    assert len(verifier_a) >= 43
|
||||
|
||||
|
||||
def test_build_authorize_url_includes_pkce_and_state():
    """The authorize URL carries the client id, PKCE challenge and state."""
    authorize_url = oauth._build_authorize_url("challenge123", "state456")
    assert authorize_url.startswith(oauth.AUTHORIZE_URL)

    required_fragments = (
        "client_id=" + oauth.HF_CLIENT_ID,
        "code_challenge=challenge123",
        "code_challenge_method=S256",
        "state=state456",
        "response_type=code",
    )
    for fragment in required_fragments:
        assert fragment in authorize_url
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Routes
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
async def test_hf_auth_token_status_empty(aiohttp_client, app):
    """With no token stored the status route reports nothing available."""
    http = await aiohttp_client(app)
    response = await http.get("/api/hf-auth-token-status")

    assert response.status == 200
    payload = await response.json()
    expected = {"token_available": False, "username": None}
    assert payload == expected
|
||||
|
||||
|
||||
async def test_hf_auth_token_status_with_token(
    aiohttp_client, app, fresh_auth_store, patched_user_dir
):
    """With a stored token and a working whoami the username is reported."""
    stored = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
    fresh_auth_store.set_token(stored)

    whoami_patch = patch(
        "app.model_downloader.api.routes._whoami_username",
        return_value="alice",
    )
    with whoami_patch:
        http = await aiohttp_client(app)
        response = await http.get("/api/hf-auth-token-status")
        assert response.status == 200
        payload = await response.json()
        assert payload == {"token_available": True, "username": "alice"}
|
||||
|
||||
|
||||
async def test_hf_auth_login_start_403_when_ineligible(aiohttp_client, app, monkeypatch):
    """An ineligible server answers login-start with 403."""
    # Force the eligibility check consulted by the route to say "no".
    monkeypatch.setattr(
        "app.model_downloader.api.routes.is_hf_auth_eligible",
        lambda: False,
    )
    http = await aiohttp_client(app)
    response = await http.post("/api/hf-auth-login-start")

    assert response.status == 403
    payload = await response.json()
    assert payload["error"]["code"] == "HF_AUTH_NOT_ELIGIBLE"
|
||||
|
||||
|
||||
async def test_hf_auth_login_start_returns_authorize_url(aiohttp_client, app, monkeypatch):
    """An eligible server answers login-start with 200 and an authorize URL."""
    routes_module = "app.model_downloader.api.routes"
    monkeypatch.setattr(routes_module + ".is_hf_auth_eligible", lambda: True)
    # The login flow is stubbed to hand back a fixed authorize URL.
    monkeypatch.setattr(
        routes_module + ".start_login_flow",
        AsyncMock(return_value="https://huggingface.co/oauth/authorize?fake=1"),
    )
    http = await aiohttp_client(app)
    response = await http.post("/api/hf-auth-login-start")

    assert response.status == 200
    payload = await response.json()
    assert payload["authorize_url"].startswith(
        "https://huggingface.co/oauth/authorize"
    )
|
||||
|
||||
|
||||
async def test_hf_auth_login_start_409_when_in_progress(aiohttp_client, app, monkeypatch):
    """``OAuthInProgressError`` from the login flow is reported as 409."""
    from app.model_downloader.hf_auth.oauth import OAuthInProgressError

    routes_module = "app.model_downloader.api.routes"
    monkeypatch.setattr(routes_module + ".is_hf_auth_eligible", lambda: True)
    # The stubbed login flow raises as if another attempt were still running.
    monkeypatch.setattr(
        routes_module + ".start_login_flow",
        AsyncMock(side_effect=OAuthInProgressError()),
    )
    http = await aiohttp_client(app)
    response = await http.post("/api/hf-auth-login-start")

    assert response.status == 409
    payload = await response.json()
    assert payload["error"]["code"] == "HF_AUTH_IN_PROGRESS"
|
||||
|
||||
|
||||
async def test_hf_auth_logout_clears_store(
    aiohttp_client, app, fresh_auth_store, patched_user_dir
):
    """The logout route returns ``logged_out`` and empties the auth store."""
    stored = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
    fresh_auth_store.set_token(stored)

    http = await aiohttp_client(app)
    response = await http.post("/api/hf-auth-logout")

    assert response.status == 200
    payload = await response.json()
    assert payload == {"logged_out": True}
    assert not fresh_auth_store.has_token()
|
||||
|
||||
|
||||
async def test_availability_includes_hf_auth_snapshot(aiohttp_client, app, monkeypatch):
    """The availability response carries an ``hf_auth`` status object."""
    monkeypatch.setattr(
        "app.model_downloader.api.routes.is_hf_auth_eligible",
        lambda: True,
    )
    http = await aiohttp_client(app)
    request_body = {"models": {}}
    response = await http.post("/api/models-availability-status", json=request_body)

    assert response.status == 200
    payload = await response.json()
    assert "hf_auth" in payload
    assert payload["hf_auth"] == {"token_available": False, "eligible": True}
|
||||
504
tests-unit/app_test/model_downloader_test.py
Normal file
504
tests-unit/app_test/model_downloader_test.py
Normal file
@ -0,0 +1,504 @@
|
||||
"""Unit tests for the server-side model download subsystem.
|
||||
|
||||
Covers the pieces that don't require talking to a real network:
|
||||
|
||||
- path parsing & allowlist (pure functions)
|
||||
- DownloadServer registry lifecycle (in-memory state)
|
||||
- API routes via aiohttp_client + folder_paths/probe_url patches
|
||||
|
||||
Streaming downloads themselves are exercised indirectly — the route-level
|
||||
tests stub out the network probe so we can verify the gating logic in
|
||||
``download_models`` without making real HTTP calls.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from app.model_downloader.allowlist import is_url_allowed
|
||||
from app.model_downloader.api.routes import register_routes
|
||||
from app.model_downloader.download_server import DownloadServer
|
||||
from app.model_downloader.gated_detection import MetadataProbeResult
|
||||
from app.model_downloader.paths import (
|
||||
InvalidModelId,
|
||||
parse_model_id,
|
||||
resolve_destination,
|
||||
resolve_existing,
|
||||
)
|
||||
|
||||
# Global asyncio mark: the sync tests below trigger a cosmetic
|
||||
# PytestWarning for each one because pytest-asyncio applies the mark
|
||||
# indiscriminately. Other tests in this repo (see custom_node_manager_test.py)
|
||||
# use the same pattern. The warnings are noise, not failures.
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Fixtures
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
@pytest.fixture
def model_root(tmp_path):
    """Create ``loras/`` and ``checkpoints/`` under a temporary models root.

    Returns the tuple ``(root, loras_dir, checkpoints_dir)``.
    """
    subdirs = []
    for folder_name in ("loras", "checkpoints"):
        folder = tmp_path / folder_name
        folder.mkdir()
        subdirs.append(folder)
    loras_dir, checkpoints_dir = subdirs
    return tmp_path, loras_dir, checkpoints_dir
|
||||
|
||||
|
||||
@pytest.fixture
def patched_folder_paths(model_root):
    """Redirect ``folder_paths`` lookups to the fake roots for one test.

    Yields the folder-name → ``([path], {extensions})`` mapping installed
    as ``folder_paths.folder_names_and_paths``.
    """
    _root, loras_dir, checkpoints_dir = model_root
    mapping = {
        folder_name: ([str(folder)], {".safetensors"})
        for folder_name, folder in (
            ("loras", loras_dir),
            ("checkpoints", checkpoints_dir),
        )
    }

    def _paths_for(name):
        # Unknown folder names resolve to an empty path list.
        entry = mapping.get(name, ([], set()))
        return entry[0]

    table_patch = patch("folder_paths.folder_names_and_paths", mapping)
    getter_patch = patch("folder_paths.get_folder_paths", side_effect=_paths_for)
    with table_patch, getter_patch:
        yield mapping
|
||||
|
||||
|
||||
@pytest.fixture
def fresh_download_server():
    """Yield the module-level ``DOWNLOAD_SERVER`` reset before and after use.

    Resetting on both sides keeps registry state from leaking between tests
    that share the singleton.
    """
    from app.model_downloader.download_server import DOWNLOAD_SERVER as singleton

    singleton.reset_for_tests()
    yield singleton
    singleton.reset_for_tests()
|
||||
|
||||
|
||||
@pytest.fixture
def app(patched_folder_paths, fresh_download_server):
    """Build an aiohttp application with the model-downloader routes attached."""
    application = web.Application()
    register_routes(application)
    return application
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Pure helpers: allowlist + path parsing
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def test_allowlist_accepts_hf_safetensors():
    """A HuggingFace ``resolve`` URL ending in ``.safetensors`` is allowed."""
    candidate = "https://huggingface.co/x/y/resolve/main/z.safetensors"
    assert is_url_allowed(candidate)
|
||||
|
||||
|
||||
def test_allowlist_accepts_civitai_pth():
    """A Civitai download URL ending in ``.pth`` is allowed."""
    candidate = "https://civitai.com/api/download/models/123.pth"
    assert is_url_allowed(candidate)
|
||||
|
||||
|
||||
def test_allowlist_rejects_unknown_host():
    """A model file on a host outside the allowlist is refused."""
    candidate = "https://example.com/x.safetensors"
    assert not is_url_allowed(candidate)
|
||||
|
||||
|
||||
def test_allowlist_rejects_api_path_on_hf():
    """An allowlisted host is still refused when the path is not a model file."""
    # Host passes the prefix check; the path has no model extension.
    candidate = "https://huggingface.co/api/models"
    assert not is_url_allowed(candidate)
|
||||
|
||||
|
||||
def test_allowlist_rejects_non_https_except_localhost():
    """Plain ``http://`` is refused for remote hosts but accepted for loopback.

    The allowlist carries two loopback prefixes (``http://localhost:`` and
    ``http://127.0.0.1:``) and two remote hosts; each one is exercised so
    that dropping or loosening any single entry is caught.
    """
    # Remote hosts are only allowlisted over https.
    assert not is_url_allowed("http://huggingface.co/x/y.safetensors")
    assert not is_url_allowed("http://civitai.com/x/y.safetensors")
    # Both loopback spellings are allowlisted over plain http.
    assert is_url_allowed("http://localhost:8000/x.safetensors")
    assert is_url_allowed("http://127.0.0.1:8000/x.safetensors")
|
||||
|
||||
|
||||
def test_parse_model_id_valid(patched_folder_paths):
    """``folder/filename`` splits into a ``(folder, filename)`` pair."""
    parsed = parse_model_id("loras/foo.safetensors")
    assert parsed == ("loras", "foo.safetensors")
|
||||
|
||||
|
||||
def test_parse_model_id_rejects_traversal(patched_folder_paths):
    """A ``../`` style id raises ``InvalidModelId``."""
    hostile_id = "../etc/passwd"
    with pytest.raises(InvalidModelId):
        parse_model_id(hostile_id)
|
||||
|
||||
|
||||
def test_parse_model_id_rejects_unknown_folder(patched_folder_paths):
    """An id whose folder is not in the patched mapping raises."""
    unregistered_id = "nope/x.safetensors"
    with pytest.raises(InvalidModelId):
        parse_model_id(unregistered_id)
|
||||
|
||||
|
||||
def test_parse_model_id_rejects_double_slash(patched_folder_paths):
    """An id containing two ``/`` separators raises ``InvalidModelId``.

    Note the input used here is a nested path (``loras/sub/...``), not a
    literal ``//``.
    """
    nested_id = "loras/sub/x.safetensors"
    with pytest.raises(InvalidModelId):
        parse_model_id(nested_id)
|
||||
|
||||
|
||||
def test_resolve_existing_returns_path_when_present(model_root, patched_folder_paths):
    """A file that exists under ``loras/`` resolves to its full path string."""
    _root, loras_dir, _checkpoints = model_root
    on_disk = loras_dir / "foo.safetensors"
    on_disk.write_bytes(b"x")

    resolved = resolve_existing("loras/foo.safetensors")
    assert resolved == str(on_disk)
|
||||
|
||||
|
||||
def test_resolve_existing_returns_none_when_absent(patched_folder_paths):
    """A file that was never written resolves to ``None``."""
    resolved = resolve_existing("loras/missing.safetensors")
    assert resolved is None
|
||||
|
||||
|
||||
def test_resolve_destination_returns_tmp_pair(model_root, patched_folder_paths):
    """The destination is the final path plus that path with ``.tmp`` appended."""
    _root, loras_dir, _checkpoints = model_root
    final_path, tmp_path_str = resolve_destination("loras/foo.safetensors")

    expected_final = str(loras_dir / "foo.safetensors")
    assert final_path == expected_final
    assert tmp_path_str == final_path + ".tmp"
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# DownloadServer registry: lifecycle, races, cancellation epoch semantics
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def test_register_is_exclusive():
    """A second registration for an id already in flight returns ``None``."""
    registry = DownloadServer()
    model_id = "loras/x.safetensors"

    winner = registry.try_register(model_id, "https://huggingface.co/a")
    loser = registry.try_register(model_id, "https://huggingface.co/b")

    assert winner is not None
    assert loser is None
    assert registry.is_downloading(model_id)
|
||||
|
||||
|
||||
def test_cancel_removes_session():
    """Cancelling a registered id returns True and ends the download state."""
    registry = DownloadServer()
    model_id = "loras/x.safetensors"
    registry.try_register(model_id, "https://huggingface.co/a")

    was_cancelled = registry.cancel(model_id)
    assert was_cancelled is True
    assert not registry.is_downloading(model_id)
|
||||
|
||||
|
||||
def test_cancel_returns_false_when_absent():
    """Cancelling an id that was never registered returns False."""
    registry = DownloadServer()
    was_cancelled = registry.cancel("loras/never.safetensors")
    assert was_cancelled is False
|
||||
|
||||
|
||||
def test_finish_only_clears_matching_epoch():
    """``finish`` from a superseded session must not evict its successor.

    Scenario: a session is cancelled, a new one is registered under the
    same id, and only then does the original worker call ``finish``.
    """
    registry = DownloadServer()
    model_id = "loras/x.safetensors"

    superseded = registry.try_register(model_id, "u1")
    registry.cancel(model_id)
    current = registry.try_register(model_id, "u2")
    assert current is not None and current.epoch != superseded.epoch

    # The late finish() from the old worker leaves the new session alone.
    registry.finish(superseded)
    assert registry.is_downloading(model_id)

    # finish() from the current session does clear it.
    registry.finish(current)
    assert not registry.is_downloading(model_id)
|
||||
|
||||
|
||||
def test_is_active_follows_cancellation():
    """A session is active after registration and inactive after cancel."""
    registry = DownloadServer()
    model_id = "loras/x.safetensors"
    session = registry.try_register(model_id, "u")

    assert registry.is_active(session)
    registry.cancel(model_id)
    assert not registry.is_active(session)
|
||||
|
||||
|
||||
def test_update_progress_tracks_fraction():
    """50 of 100 bytes shows up in the snapshot as progress 0.5."""
    registry = DownloadServer()
    model_id = "loras/x.safetensors"
    session = registry.try_register(model_id, "u")
    registry.update_progress(session, 50, 100)

    entry = registry.snapshot()[model_id]
    assert entry.bytes_downloaded == 50
    assert entry.total_bytes == 100
    assert entry.progress == 0.5
|
||||
|
||||
|
||||
def test_update_progress_with_unknown_total_keeps_progress_none():
    """With ``total_bytes=None`` the snapshot's progress stays ``None``."""
    registry = DownloadServer()
    model_id = "loras/x.safetensors"
    session = registry.try_register(model_id, "u")
    registry.update_progress(session, 50, None)

    entry = registry.snapshot()[model_id]
    assert entry.progress is None
|
||||
|
||||
|
||||
def test_cleanup_orphan_tmp_files(model_root):
    """A leftover ``.tmp`` file is removed by the sweep, not by construction."""
    _root, loras_dir, _checkpoints = model_root
    leftover = loras_dir / "stale.safetensors.tmp"
    leftover.write_bytes(b"partial")
    mapping = {"loras": ([str(loras_dir)], {".safetensors"})}

    def _paths_for(name):
        # Unknown folder names resolve to an empty path list.
        return mapping.get(name, ([], set()))[0]

    table_patch = patch("folder_paths.folder_names_and_paths", mapping)
    getter_patch = patch("folder_paths.get_folder_paths", side_effect=_paths_for)
    with table_patch, getter_patch:
        registry = DownloadServer()
        assert leftover.exists(), "sweep must not run at construction time"

        registry.sweep_orphan_tmp_files()
        assert not leftover.exists()

        # Sweeping again with nothing left to remove must not raise.
        registry.sweep_orphan_tmp_files()
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Route: POST /api/models-availability-status
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
async def test_availability_partitions_correctly(
    aiohttp_client, app, model_root, fresh_download_server
):
    """Present, absent and in-flight ids map to three distinct states."""
    _root, loras_dir, _checkpoints = model_root
    (loras_dir / "present.safetensors").write_bytes(b"x")
    fresh_download_server.try_register(
        "loras/inflight.safetensors", "http://localhost:8000/x.safetensors"
    )
    http = await aiohttp_client(app)

    # The probe is stubbed out: only state assignment is under test here.
    empty_probe = MetadataProbeResult(file_size=None, is_hf_downloadable=None)
    probe_patch = patch(
        "app.model_downloader.api.routes.probe_url",
        new=AsyncMock(return_value=empty_probe),
    )
    request_body = {
        "models": {
            "loras/present.safetensors": "http://localhost:8000/p.safetensors",
            "loras/missing.safetensors": "http://localhost:8000/m.safetensors",
            "loras/inflight.safetensors": "http://localhost:8000/x.safetensors",
        }
    }
    with probe_patch:
        response = await http.post(
            "/api/models-availability-status", json=request_body
        )

    assert response.status == 200
    payload = await response.json()
    per_model = payload["models"]
    assert per_model["loras/present.safetensors"]["state"] == "available"
    assert per_model["loras/missing.safetensors"]["state"] == "missing"
    assert per_model["loras/inflight.safetensors"]["state"] == "downloading"
    assert "hf_auth" in payload
|
||||
|
||||
|
||||
async def test_availability_invalid_id_classified_as_missing(aiohttp_client, app):
    """An unparseable model id is reported as ``missing`` with status 200."""
    http = await aiohttp_client(app)
    empty_probe = MetadataProbeResult(file_size=None, is_hf_downloadable=None)
    probe_patch = patch(
        "app.model_downloader.api.routes.probe_url",
        new=AsyncMock(return_value=empty_probe),
    )
    request_body = {
        "models": {"../etc/passwd": "http://localhost:8000/x.safetensors"}
    }
    with probe_patch:
        response = await http.post(
            "/api/models-availability-status",
            json=request_body,
        )

    assert response.status == 200
    payload = await response.json()
    assert payload["models"]["../etc/passwd"]["state"] == "missing"
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Route: POST /api/download-models — precondition gating
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
async def test_download_rejects_url_not_in_allowlist(aiohttp_client, app):
    """A URL on a non-allowlisted host is refused with ``URL_NOT_ALLOWED``."""
    http = await aiohttp_client(app)
    request_body = {
        "models": {"loras/x.safetensors": "https://evil.com/x.safetensors"}
    }
    response = await http.post("/api/download-models", json=request_body)

    assert response.status == 400
    error = (await response.json())["error"]
    assert error["code"] == "URL_NOT_ALLOWED"
|
||||
|
||||
|
||||
async def test_download_rejects_already_available(
    aiohttp_client, app, model_root
):
    """Requesting a model that already exists on disk yields a 409."""
    _root, loras_dir, _checkpoints = model_root
    (loras_dir / "x.safetensors").write_bytes(b"x")

    http = await aiohttp_client(app)
    request_body = {
        "models": {
            "loras/x.safetensors": "https://huggingface.co/a/b/resolve/main/x.safetensors"
        }
    }
    response = await http.post("/api/download-models", json=request_body)

    assert response.status == 409
    payload = await response.json()
    assert payload["error"]["code"] == "ALREADY_AVAILABLE"
|
||||
|
||||
|
||||
async def test_download_rejects_already_downloading(
    aiohttp_client, app, fresh_download_server
):
    """Requesting a model that is already registered yields a 409."""
    fresh_download_server.try_register(
        "loras/x.safetensors", "https://huggingface.co/u.safetensors"
    )

    http = await aiohttp_client(app)
    request_body = {
        "models": {
            "loras/x.safetensors": "https://huggingface.co/a/b/resolve/main/x.safetensors"
        }
    }
    response = await http.post("/api/download-models", json=request_body)

    assert response.status == 409
    payload = await response.json()
    assert payload["error"]["code"] == "ALREADY_DOWNLOADING"
|
||||
|
||||
|
||||
async def test_download_rejects_gated_model(aiohttp_client, app):
    """A probe reporting ``is_hf_downloadable=False`` blocks the download."""
    http = await aiohttp_client(app)
    gated_probe = MetadataProbeResult(file_size=None, is_hf_downloadable=False)
    probe_patch = patch(
        "app.model_downloader.api.routes.probe_url",
        new=AsyncMock(return_value=gated_probe),
    )
    request_body = {
        "models": {
            "loras/x.safetensors": "https://huggingface.co/g/r/resolve/main/x.safetensors"
        }
    }
    with probe_patch:
        response = await http.post("/api/download-models", json=request_body)

    assert response.status == 400
    payload = await response.json()
    assert payload["error"]["code"] == "MODEL_NOT_DOWNLOADABLE"
|
||||
|
||||
|
||||
async def test_download_rejects_invalid_model_id(aiohttp_client, app):
    """A traversal-style model id is refused with ``INVALID_MODEL_ID``."""
    http = await aiohttp_client(app)
    request_body = {
        "models": {"../etc/passwd": "https://huggingface.co/x.safetensors"}
    }
    response = await http.post("/api/download-models", json=request_body)

    assert response.status == 400
    payload = await response.json()
    assert payload["error"]["code"] == "INVALID_MODEL_ID"
|
||||
|
||||
|
||||
async def test_download_atomic_failure_does_not_register_partial(
    aiohttp_client, app, model_root, fresh_download_server
):
    """A batch containing one rejected model registers none of its models."""
    _root, loras_dir, _checkpoints = model_root
    (loras_dir / "already.safetensors").write_bytes(b"x")

    http = await aiohttp_client(app)
    batch = {
        "loras/already.safetensors":
            "https://huggingface.co/a/b/resolve/main/already.safetensors",
        "loras/new.safetensors":
            "https://huggingface.co/a/b/resolve/main/new.safetensors",
    }
    response = await http.post("/api/download-models", json={"models": batch})

    assert response.status == 409
    # The otherwise valid "new" entry must not have been registered as a
    # side effect of the failed batch.
    assert not fresh_download_server.is_downloading("loras/new.safetensors")
|
||||
|
||||
|
||||
async def test_download_schedules_when_all_preconditions_pass(
    aiohttp_client, app, fresh_download_server
):
    """A valid request is accepted (202) and its worker actually starts.

    Exercises the precondition pass, the registration pass and the async
    scheduling together. The streamer is replaced by ``fake_stream`` so no
    real HTTP happens while the route still runs end-to-end.
    """
    started = asyncio.Event()
    finish_signal = asyncio.Event()

    async def fake_stream(session):
        # Announce that the worker is running, then park until the test
        # releases it, and finally clear the session the way a real
        # streamer would.
        started.set()
        await finish_signal.wait()
        from app.model_downloader.download_server import DOWNLOAD_SERVER
        DOWNLOAD_SERVER.finish(session)
        return "/dev/null"

    with patch(
        "app.model_downloader.api.routes.probe_url",
        new=AsyncMock(return_value=MetadataProbeResult(file_size=42, is_hf_downloadable=True)),
    ), patch(
        "app.model_downloader.downloader.stream_to_disk", new=fake_stream
    ):
        try:
            client = await aiohttp_client(app)
            resp = await client.post(
                "/api/download-models",
                json={"models": {
                    "loras/new.safetensors":
                        "https://huggingface.co/a/b/resolve/main/new.safetensors"
                }},
            )
            assert resp.status == 202
            body = await resp.json()
            assert body["accepted"] is True
            assert body["scheduled"] == ["loras/new.safetensors"]
            # Wait for the worker to actually start.
            await asyncio.wait_for(started.wait(), timeout=2.0)
            assert fresh_download_server.is_downloading("loras/new.safetensors")
        finally:
            # Always release the fake worker: if an assertion above fails,
            # it would otherwise stay blocked on ``finish_signal`` forever.
            finish_signal.set()
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Route: POST /api/cancel-model-download-session
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
async def test_cancel_removes_active_session(
    aiohttp_client, app, fresh_download_server
):
    """The cancel route ends a registered session and reports ``cancelled``."""
    model_id = "loras/x.safetensors"
    fresh_download_server.try_register(
        model_id, "https://huggingface.co/u.safetensors"
    )

    http = await aiohttp_client(app)
    response = await http.post(
        "/api/cancel-model-download-session",
        json={"model_id": model_id},
    )

    assert response.status == 200
    payload = await response.json()
    assert payload["cancelled"] is True
    assert not fresh_download_server.is_downloading(model_id)
|
||||
|
||||
|
||||
async def test_cancel_returns_404_when_none(aiohttp_client, app):
    """Cancelling an id with no session yields 404 ``NOT_DOWNLOADING``."""
    http = await aiohttp_client(app)
    request_body = {"model_id": "loras/nothing.safetensors"}
    response = await http.post(
        "/api/cancel-model-download-session", json=request_body
    )

    assert response.status == 404
    payload = await response.json()
    assert payload["error"]["code"] == "NOT_DOWNLOADING"
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Unified availability response embeds metadata per id
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
async def test_availability_embeds_metadata(aiohttp_client, app):
    """Each model entry in the availability response carries probe metadata.

    ``file_size`` and ``is_hf_downloadable`` are returned alongside the
    state in the same response, so no separate metadata request is needed.
    """
    free_url = "https://huggingface.co/a/b/resolve/main/free.safetensors"
    gated_url = "https://huggingface.co/g/r/resolve/main/gated.safetensors"
    canned = {
        free_url: MetadataProbeResult(file_size=1024, is_hf_downloadable=True),
        gated_url: MetadataProbeResult(file_size=None, is_hf_downloadable=False),
    }

    async def fake_probe(url):
        # Serve the canned result for the requested URL.
        return canned[url]

    request_body = {
        "models": {
            "loras/free.safetensors": free_url,
            "loras/gated.safetensors": gated_url,
        }
    }
    with patch("app.model_downloader.api.routes.probe_url", new=fake_probe):
        http = await aiohttp_client(app)
        response = await http.post(
            "/api/models-availability-status",
            json=request_body,
        )

    assert response.status == 200
    per_model = (await response.json())["models"]
    free_entry = per_model["loras/free.safetensors"]
    gated_entry = per_model["loras/gated.safetensors"]
    assert free_entry["file_size"] == 1024
    assert free_entry["is_hf_downloadable"] is True
    assert gated_entry["file_size"] is None
    assert gated_entry["is_hf_downloadable"] is False
|
||||
Loading…
Reference in New Issue
Block a user