mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Merge c8b3d54cf8 into 96e0e3585b
This commit is contained in:
commit
6f5cd02796
51
app/model_downloader/allowlist.py
Normal file
51
app/model_downloader/allowlist.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
"""URL allowlist for server-side model fetches.
|
||||||
|
|
||||||
|
Mirrors the frontend's ``isModelDownloadable`` allowlist so the two flows
|
||||||
|
agree on which URLs are eligible for download. Server-side allowlisting is
|
||||||
|
the primary SSRF defense for this subsystem — workflow JSON is untrusted
|
||||||
|
input (anyone can hand-craft one), so we never let the server fetch URLs
|
||||||
|
outside this list.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
# Frontend parity: ``missingModelDownload-*.js`` exports the same triple
|
||||||
|
# (Civitai / HuggingFace / localhost). Keyed by exact hostname → allowed
|
||||||
|
# schemes, and matched against the *parsed* host (not a raw string prefix),
|
||||||
|
# so URL-userinfo tricks can't slip past — see ``is_url_allowed``.
|
||||||
|
_ALLOWED_HOSTS = {
|
||||||
|
"huggingface.co": {"https"},
|
||||||
|
"civitai.com": {"https"},
|
||||||
|
"localhost": {"http"},
|
||||||
|
"127.0.0.1": {"http"},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Frontend parity: same set as ``a = [...]`` in the bundle.
|
||||||
|
_ALLOWED_MODEL_EXTENSIONS = (
|
||||||
|
".safetensors",
|
||||||
|
".sft",
|
||||||
|
".ckpt",
|
||||||
|
".pth",
|
||||||
|
".pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_url_allowed(url: str) -> bool:
|
||||||
|
"""Check whether ``url`` is permitted as a server-side download source.
|
||||||
|
|
||||||
|
True only when the parsed host + scheme are allowlisted AND the path ends
|
||||||
|
in a model extension. Matching on ``parsed.hostname`` (not a string prefix)
|
||||||
|
defeats userinfo tricks like ``http://127.0.0.1:80@169.254.169.254/x.safetensors``,
|
||||||
|
whose real host is ``169.254.169.254``; the extension check rejects non-model
|
||||||
|
URLs on allowed hosts (e.g. ``huggingface.co/api/...``).
|
||||||
|
"""
|
||||||
|
if not isinstance(url, str) or not url:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
host = parsed.hostname
|
||||||
|
if host is None or parsed.scheme not in _ALLOWED_HOSTS.get(host, ()):
|
||||||
|
return False
|
||||||
|
return any(parsed.path.endswith(ext) for ext in _ALLOWED_MODEL_EXTENSIONS)
|
||||||
359
app/model_downloader/api/routes.py
Normal file
359
app/model_downloader/api/routes.py
Normal file
@ -0,0 +1,359 @@
|
|||||||
|
"""Aiohttp routes for the server-side model download subsystem.
|
||||||
|
|
||||||
|
Endpoint surface (all under ``/api/``, all kebab-case):
|
||||||
|
|
||||||
|
- ``POST /api/models-availability-status`` — bulk status + metadata query.
|
||||||
|
- ``POST /api/download-models`` — start a batch of downloads.
|
||||||
|
- ``POST /api/cancel-model-download-session`` — cancel a single in-flight one.
|
||||||
|
- ``GET /api/hf-auth-token-status`` — current HF login state.
|
||||||
|
- ``POST /api/hf-auth-login-start`` — begin the HF OAuth flow.
|
||||||
|
- ``POST /api/hf-auth-logout`` — drop the stored HF token.
|
||||||
|
|
||||||
|
The contract is intentionally narrow: only model_ids of the form
|
||||||
|
``<directory>/<filename>`` (validated via ``app.model_downloader.paths``)
|
||||||
|
are accepted, and only URLs on the same allowlist the frontend already
|
||||||
|
uses (HuggingFace, Civitai, localhost) can be fetched. Both are required
|
||||||
|
to keep the server out of the SSRF business for this feature.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
|
from app.model_downloader.allowlist import is_url_allowed
|
||||||
|
from app.model_downloader.download_server import (
|
||||||
|
DOWNLOAD_SERVER,
|
||||||
|
DownloadSession,
|
||||||
|
)
|
||||||
|
from app.model_downloader.downloader import schedule_batch
|
||||||
|
from app.model_downloader.gated_detection import probe_url
|
||||||
|
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
|
||||||
|
from app.model_downloader.hf_auth.eligibility import is_hf_auth_eligible
|
||||||
|
from app.model_downloader.hf_auth.oauth import (
|
||||||
|
OAuthCallbackError,
|
||||||
|
OAuthInProgressError,
|
||||||
|
start_login_flow,
|
||||||
|
)
|
||||||
|
from app.model_downloader.paths import (
|
||||||
|
InvalidModelId,
|
||||||
|
parse_model_id,
|
||||||
|
resolve_existing,
|
||||||
|
)
|
||||||
|
from app.model_downloader.api import schemas_in, schemas_out
|
||||||
|
|
||||||
|
ROUTES = web.RouteTableDef()
|
||||||
|
|
||||||
|
|
||||||
|
def register_routes(app: web.Application) -> None:
|
||||||
|
"""Wire the model-downloader routes into the running aiohttp app.
|
||||||
|
|
||||||
|
Called once from ``server.py`` during ``PromptServer`` startup.
|
||||||
|
"""
|
||||||
|
app.add_routes(ROUTES)
|
||||||
|
|
||||||
|
|
||||||
|
# ----- response helpers (same envelope as app/assets/api/routes.py) -----
|
||||||
|
|
||||||
|
|
||||||
|
ErrorCode = Literal[
|
||||||
|
"INVALID_JSON",
|
||||||
|
"INVALID_BODY",
|
||||||
|
"EMPTY_REQUEST",
|
||||||
|
"INVALID_MODEL_ID",
|
||||||
|
"URL_NOT_ALLOWED",
|
||||||
|
"ALREADY_AVAILABLE",
|
||||||
|
"ALREADY_DOWNLOADING",
|
||||||
|
"MODEL_NOT_DOWNLOADABLE",
|
||||||
|
"NOT_DOWNLOADING",
|
||||||
|
"HF_AUTH_NOT_ELIGIBLE",
|
||||||
|
"HF_AUTH_IN_PROGRESS",
|
||||||
|
"HF_AUTH_START_FAILED",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _error(status: int, code: ErrorCode, message: str, details: dict | None = None) -> web.Response:
|
||||||
|
return web.json_response(
|
||||||
|
{"error": {"code": code, "message": message, "details": details or {}}},
|
||||||
|
status=status,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _validation_error(code: ErrorCode, ve: ValidationError) -> web.Response:
|
||||||
|
return _error(400, code, "Validation failed.", {"errors": json.loads(ve.json())})
|
||||||
|
|
||||||
|
|
||||||
|
def _ok(payload: BaseModel, status: int = 200) -> web.Response:
|
||||||
|
return web.json_response(
|
||||||
|
payload.model_dump(mode="json", exclude_none=False),
|
||||||
|
status=status,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _parse_body(request: web.Request, model: type[BaseModel]) -> Any:
|
||||||
|
"""Parse a JSON body into a pydantic model or raise a 400 response."""
|
||||||
|
try:
|
||||||
|
raw = await request.json()
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return _error(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||||
|
try:
|
||||||
|
return model.model_validate(raw)
|
||||||
|
except ValidationError as ve:
|
||||||
|
return _validation_error("INVALID_BODY", ve)
|
||||||
|
|
||||||
|
|
||||||
|
# ----- 1. availability status (unified: state + metadata per id) -----
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.post("/api/models-availability-status")
|
||||||
|
async def models_availability_status(request: web.Request) -> web.Response:
|
||||||
|
"""Return per-id ``{state, progress, file_size, is_hf_downloadable}``.
|
||||||
|
|
||||||
|
State (``available`` / ``missing`` / ``downloading``) is cheap to
|
||||||
|
recompute per call. ``file_size`` and ``is_gated`` are cached
|
||||||
|
server-side per URL. ``is_hf_downloadable`` is recomputed every
|
||||||
|
call from the current token state — that's what makes login + license
|
||||||
|
acceptance show up in the UI within one poll cycle without any
|
||||||
|
frontend cache plumbing.
|
||||||
|
"""
|
||||||
|
parsed = await _parse_body(request, schemas_in.AvailabilityStatusRequest)
|
||||||
|
if isinstance(parsed, web.Response):
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
items = list(parsed.models.items())
|
||||||
|
|
||||||
|
# Run all probes concurrently; each is internally cached per URL.
|
||||||
|
probes = await asyncio.gather(*(probe_url(url) for _, url in items))
|
||||||
|
|
||||||
|
response_models: dict[str, schemas_out.ModelStatusEntry] = {}
|
||||||
|
for (model_id, _url), probe in zip(items, probes):
|
||||||
|
try:
|
||||||
|
parse_model_id(model_id)
|
||||||
|
except InvalidModelId:
|
||||||
|
# Ill-formed identifier: report as missing without 400-ing the
|
||||||
|
# whole batch — the workflow author probably typo'd.
|
||||||
|
response_models[model_id] = schemas_out.ModelStatusEntry(
|
||||||
|
state="missing",
|
||||||
|
file_size=probe.file_size,
|
||||||
|
is_hf_downloadable=probe.is_hf_downloadable,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
active = DOWNLOAD_SERVER.get(model_id)
|
||||||
|
if active is not None:
|
||||||
|
response_models[model_id] = schemas_out.ModelStatusEntry(
|
||||||
|
state="downloading",
|
||||||
|
progress=schemas_out.DownloadProgress(
|
||||||
|
bytes_downloaded=active.bytes_downloaded,
|
||||||
|
total_bytes=active.total_bytes,
|
||||||
|
progress=active.progress,
|
||||||
|
),
|
||||||
|
file_size=probe.file_size,
|
||||||
|
is_hf_downloadable=probe.is_hf_downloadable,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
state: schemas_out.ModelState = (
|
||||||
|
"available" if resolve_existing(model_id) is not None else "missing"
|
||||||
|
)
|
||||||
|
response_models[model_id] = schemas_out.ModelStatusEntry(
|
||||||
|
state=state,
|
||||||
|
file_size=probe.file_size,
|
||||||
|
is_hf_downloadable=probe.is_hf_downloadable,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _ok(schemas_out.AvailabilityStatusResponse(
|
||||||
|
models=response_models,
|
||||||
|
hf_auth=schemas_out.HfAuthStatus(
|
||||||
|
token_available=HF_AUTH_STORE.has_token(),
|
||||||
|
eligible=is_hf_auth_eligible(),
|
||||||
|
),
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
# ----- 2. start downloads -----
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.post("/api/download-models")
|
||||||
|
async def download_models(request: web.Request) -> web.Response:
|
||||||
|
parsed = await _parse_body(request, schemas_in.DownloadModelsRequest)
|
||||||
|
if isinstance(parsed, web.Response):
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
if not parsed.models:
|
||||||
|
return _error(400, "EMPTY_REQUEST", "No models supplied.")
|
||||||
|
|
||||||
|
# ----- precondition pass: validate everything BEFORE registering anything -----
|
||||||
|
# Atomic semantics: if any model fails any precondition (invalid id,
|
||||||
|
# not allow-listed URL, already on disk, already downloading, or gated),
|
||||||
|
# the entire request fails and no state is changed.
|
||||||
|
requested = list(parsed.models.items())
|
||||||
|
|
||||||
|
for model_id, url in requested:
|
||||||
|
try:
|
||||||
|
parse_model_id(model_id)
|
||||||
|
except InvalidModelId as e:
|
||||||
|
return _error(400, "INVALID_MODEL_ID", str(e),
|
||||||
|
{"model_id": model_id})
|
||||||
|
|
||||||
|
if not is_url_allowed(url):
|
||||||
|
return _error(
|
||||||
|
400, "URL_NOT_ALLOWED",
|
||||||
|
"Server-side downloads only accept HuggingFace, Civitai, "
|
||||||
|
"or localhost URLs ending in a known model extension.",
|
||||||
|
{"model_id": model_id, "url": url},
|
||||||
|
)
|
||||||
|
|
||||||
|
if resolve_existing(model_id) is not None:
|
||||||
|
return _error(409, "ALREADY_AVAILABLE",
|
||||||
|
f"Model already exists on disk: {model_id}",
|
||||||
|
{"model_id": model_id})
|
||||||
|
|
||||||
|
if DOWNLOAD_SERVER.is_downloading(model_id):
|
||||||
|
return _error(409, "ALREADY_DOWNLOADING",
|
||||||
|
f"A download for {model_id} is already in progress.",
|
||||||
|
{"model_id": model_id})
|
||||||
|
|
||||||
|
# Reachability check last — it's the only one that talks to the
|
||||||
|
# network. Concurrent probes. For HF URLs ``is_hf_downloadable``
|
||||||
|
# reflects current token access; for non-HF URLs it's None, and we
|
||||||
|
# treat that as "no info, proceed".
|
||||||
|
probes = await asyncio.gather(*(probe_url(url) for _, url in requested))
|
||||||
|
for (model_id, url), probe in zip(requested, probes):
|
||||||
|
if probe.is_hf_downloadable is False:
|
||||||
|
return _error(
|
||||||
|
400, "MODEL_NOT_DOWNLOADABLE",
|
||||||
|
f"Model {model_id} is gated on HuggingFace and the current "
|
||||||
|
f"server token (if any) does not grant access.",
|
||||||
|
{"model_id": model_id, "url": url},
|
||||||
|
)
|
||||||
|
|
||||||
|
# ----- registration pass: try_register is atomic per model_id -----
|
||||||
|
# Defensive: another request might have raced past our pre-check
|
||||||
|
# between the loop above and here. try_register handles that.
|
||||||
|
sessions: list[DownloadSession] = []
|
||||||
|
for model_id, url in requested:
|
||||||
|
session = DOWNLOAD_SERVER.try_register(model_id, url)
|
||||||
|
if session is None:
|
||||||
|
# Race: someone else got in. Roll back what we registered.
|
||||||
|
for s in sessions:
|
||||||
|
DOWNLOAD_SERVER.cancel(s.model_id)
|
||||||
|
return _error(409, "ALREADY_DOWNLOADING",
|
||||||
|
f"A download for {model_id} is already in progress (race).",
|
||||||
|
{"model_id": model_id})
|
||||||
|
sessions.append(session)
|
||||||
|
|
||||||
|
DOWNLOAD_SERVER.sweep_orphan_tmp_files()
|
||||||
|
schedule_batch(sessions)
|
||||||
|
logging.info(
|
||||||
|
"[model_downloader] scheduled %d downloads: %s",
|
||||||
|
len(sessions), [s.model_id for s in sessions],
|
||||||
|
)
|
||||||
|
|
||||||
|
return _ok(schemas_out.DownloadModelsResponse(
|
||||||
|
accepted=True,
|
||||||
|
scheduled=[s.model_id for s in sessions],
|
||||||
|
), status=202)
|
||||||
|
|
||||||
|
|
||||||
|
# ----- 3. cancel a session -----
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.post("/api/cancel-model-download-session")
|
||||||
|
async def cancel_model_download_session(request: web.Request) -> web.Response:
|
||||||
|
parsed = await _parse_body(request, schemas_in.CancelDownloadSessionRequest)
|
||||||
|
if isinstance(parsed, web.Response):
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
try:
|
||||||
|
parse_model_id(parsed.model_id)
|
||||||
|
except InvalidModelId as e:
|
||||||
|
return _error(400, "INVALID_MODEL_ID", str(e), {"model_id": parsed.model_id})
|
||||||
|
|
||||||
|
cancelled = DOWNLOAD_SERVER.cancel(parsed.model_id)
|
||||||
|
if not cancelled:
|
||||||
|
return _error(404, "NOT_DOWNLOADING",
|
||||||
|
f"No active download for {parsed.model_id}.",
|
||||||
|
{"model_id": parsed.model_id})
|
||||||
|
|
||||||
|
return _ok(schemas_out.CancelDownloadSessionResponse(cancelled=True))
|
||||||
|
|
||||||
|
|
||||||
|
# ----- 4. HuggingFace OAuth status / login start / logout -----
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.get("/api/hf-auth-token-status")
|
||||||
|
async def hf_auth_token_status(request: web.Request) -> web.Response:
|
||||||
|
"""Return whether the server holds a usable HF token + its username.
|
||||||
|
|
||||||
|
Used by the settings UI and (out-of-band) by the frontend on
|
||||||
|
login completion. ``token_available`` is true even if the cached
|
||||||
|
access_token is expired — as long as a refresh_token exists, the
|
||||||
|
user is "logged in" from their perspective.
|
||||||
|
"""
|
||||||
|
token_present = HF_AUTH_STORE.has_token()
|
||||||
|
username: Optional[str] = None
|
||||||
|
if token_present:
|
||||||
|
# Resolve the username via whoami. Done in a worker thread because
|
||||||
|
# huggingface_hub's whoami is synchronous + blocks on a network call.
|
||||||
|
tok = await HF_AUTH_STORE.get_valid_token()
|
||||||
|
if tok is not None:
|
||||||
|
try:
|
||||||
|
username = await asyncio.to_thread(_whoami_username, tok.access_token)
|
||||||
|
except Exception as e:
|
||||||
|
logging.debug("[hf_auth] whoami failed: %s", e)
|
||||||
|
return _ok(schemas_out.HfAuthTokenStatusResponse(
|
||||||
|
token_available=token_present,
|
||||||
|
username=username,
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
def _whoami_username(token: str) -> Optional[str]:
|
||||||
|
"""Sync helper: ask HF for the user name attached to a token."""
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
info = HfApi().whoami(token=token)
|
||||||
|
if isinstance(info, dict):
|
||||||
|
return info.get("name") or info.get("fullname")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.post("/api/hf-auth-login-start")
|
||||||
|
async def hf_auth_login_start(request: web.Request) -> web.Response:
|
||||||
|
"""Begin one OAuth attempt: bind the callback port, return the URL.
|
||||||
|
|
||||||
|
Rejected outright if this deployment isn't eligible (we don't
|
||||||
|
surface the option on multi-tenant / public-IP installs).
|
||||||
|
"""
|
||||||
|
if not is_hf_auth_eligible():
|
||||||
|
return _error(
|
||||||
|
403, "HF_AUTH_NOT_ELIGIBLE",
|
||||||
|
"This server is not eligible for interactive HuggingFace login. "
|
||||||
|
"It must be bound to a loopback address and not running in "
|
||||||
|
"--multi-user mode.",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
url = await start_login_flow()
|
||||||
|
except OAuthInProgressError:
|
||||||
|
return _error(
|
||||||
|
409, "HF_AUTH_IN_PROGRESS",
|
||||||
|
"Another HuggingFace login attempt is in progress. Try again "
|
||||||
|
"after it completes or times out.",
|
||||||
|
)
|
||||||
|
except OAuthCallbackError as e:
|
||||||
|
return _error(
|
||||||
|
503, "HF_AUTH_START_FAILED",
|
||||||
|
f"Could not start the HuggingFace login flow: {e}",
|
||||||
|
)
|
||||||
|
return _ok(schemas_out.HfAuthLoginStartResponse(authorize_url=url))
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.post("/api/hf-auth-logout")
|
||||||
|
async def hf_auth_logout(request: web.Request) -> web.Response:
|
||||||
|
"""Drop the in-memory + on-disk HF token."""
|
||||||
|
HF_AUTH_STORE.clear()
|
||||||
|
return _ok(schemas_out.HfAuthLogoutResponse(logged_out=True))
|
||||||
41
app/model_downloader/api/schemas_in.py
Normal file
41
app/model_downloader/api/schemas_in.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
"""Request schemas for the model-downloader API.
|
||||||
|
|
||||||
|
Each endpoint accepts a small JSON body. Pydantic enforces the shape at
|
||||||
|
the boundary; route handlers operate only on validated values past that.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class AvailabilityStatusRequest(BaseModel):
|
||||||
|
"""``POST /api/models-availability-status``.
|
||||||
|
|
||||||
|
Sent by the frontend on each poll. Each entry is ``{model_id: url}``;
|
||||||
|
the URL is the one declared in ``properties.models[i].url`` in the
|
||||||
|
workflow JSON and lets the server compute per-id metadata
|
||||||
|
(``file_size`` + ``is_hf_downloadable``) on the same request.
|
||||||
|
"""
|
||||||
|
models: dict[str, str] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadModelsRequest(BaseModel):
|
||||||
|
"""``POST /api/download-models``.
|
||||||
|
|
||||||
|
Same shape as the metadata request — the URL for each model_id.
|
||||||
|
Returns immediately after validation and scheduling.
|
||||||
|
"""
|
||||||
|
models: dict[str, str] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class CancelDownloadSessionRequest(BaseModel):
|
||||||
|
"""``POST /api/cancel-model-download-session``."""
|
||||||
|
model_id: str
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AvailabilityStatusRequest",
|
||||||
|
"DownloadModelsRequest",
|
||||||
|
"CancelDownloadSessionRequest",
|
||||||
|
]
|
||||||
81
app/model_downloader/api/schemas_out.py
Normal file
81
app/model_downloader/api/schemas_out.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
"""Response schemas for the model-downloader API."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
ModelState = Literal["available", "missing", "downloading"]
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadProgress(BaseModel):
|
||||||
|
"""Embedded in a model entry when its state is ``downloading``."""
|
||||||
|
bytes_downloaded: int
|
||||||
|
total_bytes: Optional[int] = None
|
||||||
|
progress: Optional[float] = None # fraction in [0,1]; null until total known
|
||||||
|
|
||||||
|
|
||||||
|
class ModelStatusEntry(BaseModel):
|
||||||
|
"""Everything the UI needs to render one row, in one shot.
|
||||||
|
|
||||||
|
``state`` reflects what the server has on disk + in-flight; ``file_size``
|
||||||
|
and ``is_hf_downloadable`` come from probes (intrinsic; cached).
|
||||||
|
The HF fields are populated for every poll (cached on the server),
|
||||||
|
so license-acceptance flips show up within one poll interval without
|
||||||
|
any frontend cache invalidation.
|
||||||
|
"""
|
||||||
|
state: ModelState
|
||||||
|
progress: Optional[DownloadProgress] = None
|
||||||
|
file_size: Optional[int] = None
|
||||||
|
# HF-only: True iff the server can fetch this URL with current auth
|
||||||
|
# state. False iff gated and lacking access. None for non-HF URLs.
|
||||||
|
is_hf_downloadable: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
class HfAuthStatus(BaseModel):
|
||||||
|
"""Snapshot of HF login state, embedded in availability response."""
|
||||||
|
token_available: bool
|
||||||
|
eligible: bool
|
||||||
|
|
||||||
|
|
||||||
|
class AvailabilityStatusResponse(BaseModel):
|
||||||
|
models: dict[str, ModelStatusEntry]
|
||||||
|
hf_auth: HfAuthStatus
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadModelsResponse(BaseModel):
|
||||||
|
accepted: bool
|
||||||
|
scheduled: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class CancelDownloadSessionResponse(BaseModel):
|
||||||
|
cancelled: bool
|
||||||
|
|
||||||
|
|
||||||
|
class HfAuthTokenStatusResponse(BaseModel):
|
||||||
|
token_available: bool
|
||||||
|
username: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class HfAuthLoginStartResponse(BaseModel):
|
||||||
|
authorize_url: str
|
||||||
|
|
||||||
|
|
||||||
|
class HfAuthLogoutResponse(BaseModel):
|
||||||
|
logged_out: bool
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ModelState",
|
||||||
|
"DownloadProgress",
|
||||||
|
"ModelStatusEntry",
|
||||||
|
"HfAuthStatus",
|
||||||
|
"AvailabilityStatusResponse",
|
||||||
|
"DownloadModelsResponse",
|
||||||
|
"CancelDownloadSessionResponse",
|
||||||
|
"HfAuthTokenStatusResponse",
|
||||||
|
"HfAuthLoginStartResponse",
|
||||||
|
"HfAuthLogoutResponse",
|
||||||
|
]
|
||||||
179
app/model_downloader/download_server.py
Normal file
179
app/model_downloader/download_server.py
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
"""Process-wide registry of in-flight model downloads.
|
||||||
|
|
||||||
|
A single instance, ``DOWNLOAD_SERVER``, tracks every currently-running
|
||||||
|
server-side model fetch. Designed to be safe with multiple concurrent
|
||||||
|
clients hitting the API: each model_id has at most one active session,
|
||||||
|
and the API rejects requests that conflict with in-flight downloads.
|
||||||
|
|
||||||
|
Cancellation is cooperative — the download loop checks ``is_active`` on
|
||||||
|
its own session between chunks and raises ``DownloadCancelled`` when the
|
||||||
|
session has been removed from the registry. This avoids the complications
|
||||||
|
of ``Task.cancel()`` from outside the loop while still giving deterministic
|
||||||
|
rollback semantics (the worker is responsible for deleting its own
|
||||||
|
``.tmp`` on the cancel path).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.model_downloader.paths import iter_all_tmp_paths
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadCancelled(Exception):
|
||||||
|
"""Raised by the streaming loop when its session has been removed
|
||||||
|
from the registry (cancellation request) and the worker should roll
|
||||||
|
back its ``.tmp`` file."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DownloadSession:
|
||||||
|
"""One in-flight download.
|
||||||
|
|
||||||
|
``progress`` is a fraction in ``[0.0, 1.0]``; ``None`` until the first
|
||||||
|
byte arrives and we know whether the response carries a
|
||||||
|
``Content-Length``. ``total_bytes`` mirrors that header when present.
|
||||||
|
"""
|
||||||
|
model_id: str
|
||||||
|
url: str
|
||||||
|
progress: Optional[float] = None
|
||||||
|
bytes_downloaded: int = 0
|
||||||
|
total_bytes: Optional[int] = None
|
||||||
|
# Sequence number used solely as identity for the cancellation check —
|
||||||
|
# so that "cancel + restart" doesn't get confused by stale workers.
|
||||||
|
epoch: int = field(default_factory=lambda: 0)
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadServer:
|
||||||
|
"""Singleton registry of active downloads.
|
||||||
|
|
||||||
|
All mutation goes through this object so concurrent route handlers
|
||||||
|
see a consistent view. The ``_lock`` is a plain threading lock
|
||||||
|
because the registry is consulted from both the asyncio event-loop
|
||||||
|
thread (route handlers) and from any worker coroutines spawned to
|
||||||
|
perform downloads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._sessions: dict[str, DownloadSession] = {}
|
||||||
|
self._epoch_counter = 0
|
||||||
|
self._orphan_sweep_done = False
|
||||||
|
|
||||||
|
# ----- lifecycle -----
|
||||||
|
|
||||||
|
def sweep_orphan_tmp_files(self) -> None:
|
||||||
|
"""Idempotently sweep ``*.tmp`` files left by crashed downloads.
|
||||||
|
|
||||||
|
Deferred off the import path so module load doesn't block on
|
||||||
|
filesystem I/O against potentially-slow mounts. Each route handler
|
||||||
|
that might create a new ``.tmp`` runs this exactly once.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if self._orphan_sweep_done:
|
||||||
|
return
|
||||||
|
self._orphan_sweep_done = True
|
||||||
|
for path in iter_all_tmp_paths():
|
||||||
|
try:
|
||||||
|
os.remove(path)
|
||||||
|
logging.info("[model_downloader] removed orphan tmp file: %s", path)
|
||||||
|
except OSError as e:
|
||||||
|
logging.warning("[model_downloader] could not remove %s: %s", path, e)
|
||||||
|
|
||||||
|
# ----- queries -----
|
||||||
|
|
||||||
|
def is_downloading(self, model_id: str) -> bool:
|
||||||
|
with self._lock:
|
||||||
|
return model_id in self._sessions
|
||||||
|
|
||||||
|
def get(self, model_id: str) -> Optional[DownloadSession]:
|
||||||
|
with self._lock:
|
||||||
|
return self._sessions.get(model_id)
|
||||||
|
|
||||||
|
def snapshot(self) -> dict[str, DownloadSession]:
|
||||||
|
"""Return a shallow copy of the current sessions map."""
|
||||||
|
with self._lock:
|
||||||
|
return dict(self._sessions)
|
||||||
|
|
||||||
|
# ----- mutations -----
|
||||||
|
|
||||||
|
def try_register(self, model_id: str, url: str) -> Optional[DownloadSession]:
|
||||||
|
"""Atomically register a new session iff none exists for ``model_id``.
|
||||||
|
|
||||||
|
Returns the new session on success, ``None`` if a session is already
|
||||||
|
in flight. Callers must check the return value — the caller is the
|
||||||
|
sole owner of the session it gets back.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if model_id in self._sessions:
|
||||||
|
return None
|
||||||
|
self._epoch_counter += 1
|
||||||
|
session = DownloadSession(
|
||||||
|
model_id=model_id,
|
||||||
|
url=url,
|
||||||
|
epoch=self._epoch_counter,
|
||||||
|
)
|
||||||
|
self._sessions[model_id] = session
|
||||||
|
return session
|
||||||
|
|
||||||
|
def update_progress(
|
||||||
|
self,
|
||||||
|
session: DownloadSession,
|
||||||
|
bytes_downloaded: int,
|
||||||
|
total_bytes: Optional[int],
|
||||||
|
) -> None:
|
||||||
|
"""Update progress on a session. No-op if the session has been
|
||||||
|
removed (cancelled) — caller should check ``is_active`` separately."""
|
||||||
|
with self._lock:
|
||||||
|
current = self._sessions.get(session.model_id)
|
||||||
|
if current is None or current.epoch != session.epoch:
|
||||||
|
return
|
||||||
|
current.bytes_downloaded = bytes_downloaded
|
||||||
|
current.total_bytes = total_bytes
|
||||||
|
if total_bytes and total_bytes > 0:
|
||||||
|
current.progress = min(1.0, bytes_downloaded / total_bytes)
|
||||||
|
|
||||||
|
def is_active(self, session: DownloadSession) -> bool:
|
||||||
|
"""True iff this exact session is still the registered one for
|
||||||
|
its model_id. False after cancellation, after completion, or if
|
||||||
|
another session has replaced it."""
|
||||||
|
with self._lock:
|
||||||
|
current = self._sessions.get(session.model_id)
|
||||||
|
return current is not None and current.epoch == session.epoch
|
||||||
|
|
||||||
|
def finish(self, session: DownloadSession) -> None:
|
||||||
|
"""Remove a completed (or cancelled) session from the registry.
|
||||||
|
|
||||||
|
Safe to call multiple times. Only removes if the epoch matches —
|
||||||
|
we never accidentally evict a *newer* session for the same model_id.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
current = self._sessions.get(session.model_id)
|
||||||
|
if current is not None and current.epoch == session.epoch:
|
||||||
|
del self._sessions[session.model_id]
|
||||||
|
|
||||||
|
def reset_for_tests(self) -> None:
|
||||||
|
"""Clear all sessions and reset the epoch counter. Test-only."""
|
||||||
|
with self._lock:
|
||||||
|
self._sessions.clear()
|
||||||
|
self._epoch_counter = 0
|
||||||
|
|
||||||
|
def cancel(self, model_id: str) -> bool:
|
||||||
|
"""Remove the session registered for ``model_id``.
|
||||||
|
|
||||||
|
Returns True if there was an active session to cancel. The worker
|
||||||
|
will discover the cancellation on its next ``is_active`` check
|
||||||
|
and roll back its ``.tmp`` file.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if model_id in self._sessions:
|
||||||
|
del self._sessions[model_id]
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
DOWNLOAD_SERVER = DownloadServer()
|
||||||
216
app/model_downloader/downloader.py
Normal file
216
app/model_downloader/downloader.py
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
"""Streaming download worker with progress reporting and cancellation.
|
||||||
|
|
||||||
|
Each download writes to ``<final_path>.tmp`` and atomically renames into
|
||||||
|
place on success. Between chunks the worker checks the registry for
|
||||||
|
cancellation (via ``DownloadServer.is_active``) and rolls back its
|
||||||
|
``.tmp`` on cancel or on any error.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from app.model_downloader.download_server import (
|
||||||
|
DOWNLOAD_SERVER,
|
||||||
|
DownloadCancelled,
|
||||||
|
DownloadSession,
|
||||||
|
)
|
||||||
|
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
|
||||||
|
from app.model_downloader.hf_url import is_hf_url
|
||||||
|
from app.model_downloader.http_client import get_session, parse_content_length
|
||||||
|
from app.model_downloader.paths import resolve_destination
|
||||||
|
|
||||||
|
|
||||||
|
CHUNK_SIZE = 64 * 1024 # 64 KiB — same scale as other ComfyUI download paths.
|
||||||
|
REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=120)
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_to_disk(session: DownloadSession) -> str:
|
||||||
|
"""Run a single download to completion or cancellation.
|
||||||
|
|
||||||
|
Returns the final on-disk path on success. Removes the ``.tmp`` and
|
||||||
|
raises on cancellation or failure. The session is finished
|
||||||
|
(removed from the registry) exactly once, here — callers do not
|
||||||
|
need to call ``DOWNLOAD_SERVER.finish`` themselves.
|
||||||
|
"""
|
||||||
|
final_path, tmp_path = resolve_destination(session.model_id, session.epoch)
|
||||||
|
os.makedirs(os.path.dirname(final_path), exist_ok=True)
|
||||||
|
|
||||||
|
# Wipe any stale .tmp from a previous failed attempt before we start —
|
||||||
|
# otherwise a partial body could masquerade as our completed download
|
||||||
|
# when the rename finally happens.
|
||||||
|
_remove_if_exists(tmp_path)
|
||||||
|
|
||||||
|
bytes_seen = 0
|
||||||
|
try:
|
||||||
|
http = await get_session()
|
||||||
|
headers = _auth_headers_for(session.url)
|
||||||
|
logging.info(
|
||||||
|
"[model_downloader] starting GET %s (auth=%s)",
|
||||||
|
session.url, "yes" if "Authorization" in headers else "no",
|
||||||
|
)
|
||||||
|
async with http.get(
|
||||||
|
session.url,
|
||||||
|
allow_redirects=True,
|
||||||
|
timeout=REQUEST_TIMEOUT,
|
||||||
|
headers=headers,
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
# Capture a snippet of the response body so 4xx/5xx aren't
|
||||||
|
# opaque in the logs — HF returns JSON or HTML with a
|
||||||
|
# human-readable reason on failures.
|
||||||
|
body_snippet = await _read_short(resp)
|
||||||
|
logging.warning(
|
||||||
|
"[model_downloader] GET %s failed: status=%d final_url=%s body=%s",
|
||||||
|
session.url, resp.status, str(resp.url), body_snippet,
|
||||||
|
)
|
||||||
|
raise DownloadError(
|
||||||
|
f"unexpected HTTP {resp.status} fetching {session.url}: {body_snippet}",
|
||||||
|
status=resp.status,
|
||||||
|
)
|
||||||
|
|
||||||
|
total = parse_content_length(resp.headers.get("Content-Length"))
|
||||||
|
DOWNLOAD_SERVER.update_progress(session, 0, total)
|
||||||
|
|
||||||
|
with open(tmp_path, "wb") as f:
|
||||||
|
async for chunk in resp.content.iter_chunked(CHUNK_SIZE):
|
||||||
|
# Cancellation check between chunks. Cheap and means
|
||||||
|
# cancellation latency is bounded by one chunk plus
|
||||||
|
# one ``write()`` — typically well under a second
|
||||||
|
# even on slow disks.
|
||||||
|
if not DOWNLOAD_SERVER.is_active(session):
|
||||||
|
raise DownloadCancelled()
|
||||||
|
f.write(chunk)
|
||||||
|
bytes_seen += len(chunk)
|
||||||
|
DOWNLOAD_SERVER.update_progress(session, bytes_seen, total)
|
||||||
|
|
||||||
|
# Final cancellation check before we promote the .tmp to the real
|
||||||
|
# filename — avoids the awkward case where cancel arrives during
|
||||||
|
# the very last chunk and we'd otherwise commit anyway.
|
||||||
|
if not DOWNLOAD_SERVER.is_active(session):
|
||||||
|
raise DownloadCancelled()
|
||||||
|
|
||||||
|
# Size verification before commit. aiohttp already raises
|
||||||
|
# ClientPayloadError on a truncated Content-Length/chunked body,
|
||||||
|
# but this also catches the HTTP/1.0-style case (no Content-Length
|
||||||
|
# + Connection: close) where a short read can masquerade as a
|
||||||
|
# complete download.
|
||||||
|
if total is not None and bytes_seen != total:
|
||||||
|
raise DownloadError(
|
||||||
|
f"size mismatch for {session.model_id}: "
|
||||||
|
f"got {bytes_seen} of {total} bytes from {session.url}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Atomic rename. os.replace is atomic within the same filesystem,
|
||||||
|
# which is guaranteed here because tmp lives alongside final_path.
|
||||||
|
os.replace(tmp_path, final_path)
|
||||||
|
logging.info(
|
||||||
|
"[model_downloader] downloaded %s (%d bytes) from %s",
|
||||||
|
session.model_id, bytes_seen, session.url,
|
||||||
|
)
|
||||||
|
return final_path
|
||||||
|
|
||||||
|
except DownloadCancelled:
|
||||||
|
logging.info("[model_downloader] cancelled: %s", session.model_id)
|
||||||
|
_remove_if_exists(tmp_path)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(
|
||||||
|
"[model_downloader] failed: %s from %s: %s: %s",
|
||||||
|
session.model_id, session.url, type(e).__name__, e,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
_remove_if_exists(tmp_path)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# In all terminal states (success / cancel / error) drop the
|
||||||
|
# session from the registry. Idempotent — only removes if we're
|
||||||
|
# still the live epoch for this model_id.
|
||||||
|
DOWNLOAD_SERVER.finish(session)
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadError(Exception):
|
||||||
|
"""Network / protocol error during a download."""
|
||||||
|
|
||||||
|
def __init__(self, message: str, status: Optional[int] = None) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
self.status = status
|
||||||
|
|
||||||
|
|
||||||
|
async def _read_short(resp: aiohttp.ClientResponse, limit: int = 512) -> str:
|
||||||
|
"""Read up to ``limit`` bytes of a response body for logging.
|
||||||
|
|
||||||
|
Used to surface the JSON/HTML reason from an HF non-2xx response in
|
||||||
|
server logs instead of just the status code. Best-effort: any
|
||||||
|
error here is swallowed.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
raw = await resp.content.read(limit)
|
||||||
|
return raw.decode("utf-8", errors="replace").strip()
|
||||||
|
except Exception:
|
||||||
|
return "<unreadable>"
|
||||||
|
|
||||||
|
|
||||||
|
def _auth_headers_for(url: str) -> dict[str, str]:
|
||||||
|
"""Return any auth headers we should add to the GET for ``url``.
|
||||||
|
|
||||||
|
For HuggingFace URLs we inject the user's OAuth access token as a
|
||||||
|
Bearer header — this is HF's documented way to access gated repos
|
||||||
|
(see ``huggingface_hub.hf_hub_download``'s wire format). For every
|
||||||
|
other host we send no extra headers; allowlisted public files
|
||||||
|
don't need them and we don't want to leak tokens to other hosts.
|
||||||
|
"""
|
||||||
|
if not is_hf_url(url):
|
||||||
|
return {}
|
||||||
|
tok = HF_AUTH_STORE.get_token_sync()
|
||||||
|
if tok is None or not tok.access_token:
|
||||||
|
return {}
|
||||||
|
return {"Authorization": f"Bearer {tok.access_token}"}
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_if_exists(path: str) -> None:
|
||||||
|
try:
|
||||||
|
os.remove(path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
except OSError as e:
|
||||||
|
logging.warning("[model_downloader] could not remove %s: %s", path, e)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_batch_sequential(sessions: list[DownloadSession]) -> None:
|
||||||
|
"""Run a list of sessions one after the other.
|
||||||
|
|
||||||
|
Each session is independent: a failure or cancellation of one does
|
||||||
|
not abort the rest. Cancellations are observable via the registry
|
||||||
|
*before* a given download starts, so a session that's been
|
||||||
|
pre-cancelled (cancel before the worker reached it) just gets skipped.
|
||||||
|
"""
|
||||||
|
for session in sessions:
|
||||||
|
# If the session got cancelled before its turn, skip without
|
||||||
|
# touching disk. This is what makes the per-request "sequential
|
||||||
|
# but cancellable" semantic work.
|
||||||
|
if not DOWNLOAD_SERVER.is_active(session):
|
||||||
|
DOWNLOAD_SERVER.finish(session)
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await stream_to_disk(session)
|
||||||
|
except DownloadCancelled:
|
||||||
|
# Already logged + tmp removed inside stream_to_disk.
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
# stream_to_disk already logged. Continue with the rest of the batch.
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
def schedule_batch(sessions: list[DownloadSession]) -> asyncio.Task:
|
||||||
|
"""Kick off ``run_batch_sequential`` on the running event loop.
|
||||||
|
|
||||||
|
Returned task is fire-and-forget; the API handler returns immediately
|
||||||
|
after scheduling and clients observe progress via the polling endpoints.
|
||||||
|
"""
|
||||||
|
return asyncio.create_task(run_batch_sequential(sessions))
|
||||||
245
app/model_downloader/gated_detection.py
Normal file
245
app/model_downloader/gated_detection.py
Normal file
@ -0,0 +1,245 @@
|
|||||||
|
"""Per-URL probes for the unified availability endpoint.
|
||||||
|
|
||||||
|
Three cached/derived facts per URL:
|
||||||
|
|
||||||
|
- ``is_gated`` intrinsic to the model; cached forever once known.
|
||||||
|
Determined by ``auth_check(repo_id, token=None)``:
|
||||||
|
``GatedRepoError`` → True, success → False.
|
||||||
|
|
||||||
|
- ``is_hf_downloadable`` depends on the *current* token; recomputed every
|
||||||
|
call. For non-gated URLs this is trivially True
|
||||||
|
(no HF call needed). For gated URLs we run
|
||||||
|
``auth_check`` with the stored token each call.
|
||||||
|
|
||||||
|
- ``file_size`` intrinsic to the file. Cached forever once
|
||||||
|
determined (including ``None`` on transient
|
||||||
|
failure — we don't retry). We only attempt the
|
||||||
|
HEAD when we already know the URL is downloadable
|
||||||
|
to us; that way a failed-because-gated probe
|
||||||
|
never lands as a cached ``None``.
|
||||||
|
|
||||||
|
Caches are per-process, in-memory; small, no eviction needed for the
|
||||||
|
workflow-scale (~tens of URLs). Concurrent calls for the same URL
|
||||||
|
deduplicate via per-URL ``asyncio.Lock``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
from huggingface_hub.errors import (
|
||||||
|
GatedRepoError,
|
||||||
|
HfHubHTTPError,
|
||||||
|
RepositoryNotFoundError,
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.model_downloader.allowlist import is_url_allowed
|
||||||
|
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
|
||||||
|
from app.model_downloader.hf_url import is_hf_url, repo_id_from_url
|
||||||
|
from app.model_downloader.http_client import get_session, parse_content_length
|
||||||
|
|
||||||
|
|
||||||
|
_HEAD_TIMEOUT = aiohttp.ClientTimeout(total=15)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProbeResult:
|
||||||
|
file_size: Optional[int]
|
||||||
|
is_hf_downloadable: Optional[bool]
|
||||||
|
|
||||||
|
|
||||||
|
# --- caches -------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
# url → bool. Whether this URL's HF repo gates access. Intrinsic to the
|
||||||
|
# model — never changes for a given URL.
|
||||||
|
_is_gated_cache: dict[str, bool] = {}
|
||||||
|
|
||||||
|
# url → Optional[int]. The file's size in bytes, ``None`` if a probe
|
||||||
|
# was attempted and produced no answer. **Only populated when we knew
|
||||||
|
# the URL was downloadable to us at probe time** — so gated-without-
|
||||||
|
# access never lands a ``None`` here that we'd be stuck with after login.
|
||||||
|
_file_size_cache: dict[str, Optional[int]] = {}
|
||||||
|
|
||||||
|
# Per-URL locks for single-flight probes — when multiple polls arrive
|
||||||
|
# in the same tick for the same URL, exactly one of them runs the HF
|
||||||
|
# call and the others wait on the result.
|
||||||
|
_locks: dict[str, asyncio.Lock] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _lock_for(url: str) -> asyncio.Lock:
|
||||||
|
lock = _locks.get(url)
|
||||||
|
if lock is None:
|
||||||
|
lock = asyncio.Lock()
|
||||||
|
_locks[url] = lock
|
||||||
|
return lock
|
||||||
|
|
||||||
|
|
||||||
|
def clear_caches_for_tests() -> None:
|
||||||
|
"""Test-only: drop everything."""
|
||||||
|
_is_gated_cache.clear()
|
||||||
|
_file_size_cache.clear()
|
||||||
|
_locks.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# --- public entrypoint --------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
async def probe_url(url: str) -> ProbeResult:
|
||||||
|
"""Return downloadability + size for one URL, using caches where safe."""
|
||||||
|
if not is_url_allowed(url):
|
||||||
|
return ProbeResult(file_size=None, is_hf_downloadable=None)
|
||||||
|
if not is_hf_url(url):
|
||||||
|
# Non-HF: ``is_hf_downloadable`` is "not applicable" (None).
|
||||||
|
# Size we still cache so we don't HEAD on every poll.
|
||||||
|
size = await _get_or_probe_size(url, token=None)
|
||||||
|
return ProbeResult(file_size=size, is_hf_downloadable=None)
|
||||||
|
|
||||||
|
repo_id = repo_id_from_url(url)
|
||||||
|
if repo_id is None:
|
||||||
|
return ProbeResult(file_size=None, is_hf_downloadable=None)
|
||||||
|
|
||||||
|
# Determine intrinsic gating once.
|
||||||
|
gated = await _resolve_is_gated(url, repo_id)
|
||||||
|
if gated is None:
|
||||||
|
return ProbeResult(file_size=None, is_hf_downloadable=None)
|
||||||
|
|
||||||
|
# Compute current-token downloadability per call.
|
||||||
|
tok = await HF_AUTH_STORE.get_valid_token()
|
||||||
|
token_str: Optional[str] = tok.access_token if tok else None
|
||||||
|
if not gated:
|
||||||
|
is_hf_downloadable: Optional[bool] = True
|
||||||
|
else:
|
||||||
|
is_hf_downloadable = await _auth_check_with_token(repo_id, token_str)
|
||||||
|
|
||||||
|
if is_hf_downloadable is True:
|
||||||
|
size = await _get_or_probe_size(url, token=token_str)
|
||||||
|
else:
|
||||||
|
# Skip the HEAD entirely — would 401 and we'd be stuck with
|
||||||
|
# cached None that survives a later login.
|
||||||
|
size = None
|
||||||
|
|
||||||
|
return ProbeResult(file_size=size, is_hf_downloadable=is_hf_downloadable)
|
||||||
|
|
||||||
|
|
||||||
|
# --- gated/auth probes --------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_is_gated(url: str, repo_id: str) -> Optional[bool]:
|
||||||
|
"""Decide once whether ``repo_id`` is a gated repo."""
|
||||||
|
cached = _is_gated_cache.get(url)
|
||||||
|
if cached is not None:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
async with _lock_for(url):
|
||||||
|
cached = _is_gated_cache.get(url)
|
||||||
|
if cached is not None:
|
||||||
|
return cached
|
||||||
|
# Probe anonymously (token=None) on purpose: an unauthenticated
|
||||||
|
# auth_check is what makes HF raise GatedRepoError for gated repos.
|
||||||
|
# With a token, a gated-but-accepted repo would succeed and look
|
||||||
|
# ungated.
|
||||||
|
try:
|
||||||
|
await asyncio.to_thread(_auth_check_sync, repo_id, None)
|
||||||
|
_is_gated_cache[url] = False
|
||||||
|
return False
|
||||||
|
except GatedRepoError:
|
||||||
|
_is_gated_cache[url] = True
|
||||||
|
return True
|
||||||
|
except RepositoryNotFoundError:
|
||||||
|
# Repo doesn't exist publicly. Treat as gated — we can't
|
||||||
|
# serve it without auth, and an authenticated check might
|
||||||
|
# still succeed if it's a private repo the user can see.
|
||||||
|
_is_gated_cache[url] = True
|
||||||
|
return True
|
||||||
|
except (HfHubHTTPError, Exception) as e:
|
||||||
|
logging.debug(
|
||||||
|
"[hf_auth] is_gated probe failed for %s (will retry): %s",
|
||||||
|
repo_id, e,
|
||||||
|
)
|
||||||
|
return None # don't cache; retry next call
|
||||||
|
|
||||||
|
|
||||||
|
async def _auth_check_with_token(
|
||||||
|
repo_id: str, token: Optional[str]
|
||||||
|
) -> Optional[bool]:
|
||||||
|
"""Auth-check with the supplied token. True/False/None per outcome."""
|
||||||
|
try:
|
||||||
|
await asyncio.to_thread(_auth_check_sync, repo_id, token)
|
||||||
|
return True
|
||||||
|
except GatedRepoError:
|
||||||
|
return False
|
||||||
|
except RepositoryNotFoundError:
|
||||||
|
return False
|
||||||
|
except HfHubHTTPError as e:
|
||||||
|
# 401/403 covers org-SSO-required, revoked tokens, and similar —
|
||||||
|
# all of which mean "can't fetch right now" from the user's POV.
|
||||||
|
status = getattr(getattr(e, "response", None), "status_code", None)
|
||||||
|
if status in (401, 403):
|
||||||
|
return False
|
||||||
|
logging.debug(
|
||||||
|
"[hf_auth] auth_check transient failure for %s: %s", repo_id, e,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("[hf_auth] unexpected auth_check error for %s: %s", repo_id, e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _auth_check_sync(repo_id: str, token: Optional[str]) -> None:
|
||||||
|
"""Thin sync wrapper around ``HfApi.auth_check`` for ``asyncio.to_thread``."""
|
||||||
|
HfApi().auth_check(repo_id, token=token)
|
||||||
|
|
||||||
|
|
||||||
|
# --- size probe ---------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_or_probe_size(url: str, token: Optional[str]) -> Optional[int]:
|
||||||
|
"""Return the cached size or HEAD the URL once and cache the result."""
|
||||||
|
if url in _file_size_cache:
|
||||||
|
return _file_size_cache[url]
|
||||||
|
|
||||||
|
async with _lock_for(url):
|
||||||
|
if url in _file_size_cache:
|
||||||
|
return _file_size_cache[url]
|
||||||
|
size = await _probe_size_once(url, token=token)
|
||||||
|
_file_size_cache[url] = size
|
||||||
|
return size
|
||||||
|
|
||||||
|
|
||||||
|
async def _probe_size_once(url: str, token: Optional[str]) -> Optional[int]:
|
||||||
|
"""HEAD the URL and return the file size in bytes, or None on failure.
|
||||||
|
|
||||||
|
HuggingFace serves LFS-tracked files via a 302 to a signed CDN URL.
|
||||||
|
The real file size lives in the non-standard ``X-Linked-Size`` header
|
||||||
|
on that 302 response (``Content-Length`` is the redirect-body length).
|
||||||
|
Disabling redirect-follow lets us read either header on the same
|
||||||
|
response:
|
||||||
|
|
||||||
|
- LFS files: 302 + ``X-Linked-Size``
|
||||||
|
- Small/non-LFS files: 200 + ``Content-Length``
|
||||||
|
"""
|
||||||
|
headers = {"Authorization": f"Bearer {token}"} if token else {}
|
||||||
|
try:
|
||||||
|
session = await get_session()
|
||||||
|
async with session.head(
|
||||||
|
url, allow_redirects=False, timeout=_HEAD_TIMEOUT, headers=headers,
|
||||||
|
) as resp:
|
||||||
|
linked = parse_content_length(resp.headers.get("X-Linked-Size"))
|
||||||
|
if linked is not None:
|
||||||
|
return linked
|
||||||
|
if resp.status == 200:
|
||||||
|
return parse_content_length(resp.headers.get("Content-Length"))
|
||||||
|
return None
|
||||||
|
except (aiohttp.ClientError, TimeoutError, OSError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Backward-compat shim so consumers that still import the old name keep
|
||||||
|
# building during the refactor; can be removed once routes are updated.
|
||||||
|
MetadataProbeResult = ProbeResult
|
||||||
121
app/model_downloader/hf_auth/auth_store.py
Normal file
121
app/model_downloader/hf_auth/auth_store.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
"""In-memory token cache with lazy disk persistence + refresh.
|
||||||
|
|
||||||
|
Public surface is the ``HF_AUTH_STORE`` singleton. Callers ask
|
||||||
|
``get_valid_token()``; the store transparently refreshes from disk
|
||||||
|
on first use, refreshes via the OAuth refresh_token if the cached
|
||||||
|
access_token is expired, and returns ``None`` if neither path works.
|
||||||
|
|
||||||
|
The refresh path imports ``oauth.refresh_access_token`` lazily to
|
||||||
|
avoid an import cycle (oauth needs the store to save tokens it
|
||||||
|
acquires).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.model_downloader.hf_auth.token_store import (
|
||||||
|
Token,
|
||||||
|
delete_token,
|
||||||
|
load_token,
|
||||||
|
save_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HfAuthStore:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._token: Optional[Token] = None
|
||||||
|
self._loaded_from_disk = False
|
||||||
|
|
||||||
|
def _ensure_loaded(self) -> None:
|
||||||
|
"""Read the disk token into memory on first access."""
|
||||||
|
if self._loaded_from_disk:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
if self._loaded_from_disk:
|
||||||
|
return
|
||||||
|
self._token = load_token()
|
||||||
|
self._loaded_from_disk = True
|
||||||
|
|
||||||
|
def has_token(self) -> bool:
|
||||||
|
"""Cheap check: is there any token in memory?
|
||||||
|
|
||||||
|
Does not attempt refresh; an expired-but-refreshable token still
|
||||||
|
counts as "logged in" from the user's perspective.
|
||||||
|
"""
|
||||||
|
self._ensure_loaded()
|
||||||
|
return self._token is not None
|
||||||
|
|
||||||
|
def _store_token_locked(self, token: Token) -> None:
|
||||||
|
"""Set the in-memory token and persist it to disk.
|
||||||
|
|
||||||
|
Caller must already hold ``self._lock``. Keeping the disk write inside
|
||||||
|
the lock means memory and disk flip together — a concurrent ``clear()``
|
||||||
|
or refresh can't interleave between them.
|
||||||
|
"""
|
||||||
|
self._token = token
|
||||||
|
self._loaded_from_disk = True
|
||||||
|
save_token(token)
|
||||||
|
|
||||||
|
def set_token(self, token: Token) -> None:
|
||||||
|
"""Replace the in-memory token and persist to disk (atomically)."""
|
||||||
|
with self._lock:
|
||||||
|
self._store_token_locked(token)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Forget the token in memory and on disk (logout)."""
|
||||||
|
with self._lock:
|
||||||
|
self._token = None
|
||||||
|
self._loaded_from_disk = True
|
||||||
|
delete_token()
|
||||||
|
|
||||||
|
def get_token_sync(self) -> Optional[Token]:
|
||||||
|
"""Return the cached token without refreshing.
|
||||||
|
|
||||||
|
Sync callers (e.g. constructing an Authorization header in a
|
||||||
|
non-async path) use this. They accept an expired token over
|
||||||
|
``None``; HF will simply return 401 and the caller can decide
|
||||||
|
what to do.
|
||||||
|
"""
|
||||||
|
self._ensure_loaded()
|
||||||
|
return self._token
|
||||||
|
|
||||||
|
async def get_valid_token(self) -> Optional[Token]:
|
||||||
|
"""Return a fresh token, refreshing via OAuth if necessary.
|
||||||
|
|
||||||
|
Returns ``None`` if there's no cached token at all, or if the
|
||||||
|
cached token is expired and refresh failed. Callers should
|
||||||
|
treat that as "user is not logged in".
|
||||||
|
"""
|
||||||
|
self._ensure_loaded()
|
||||||
|
with self._lock:
|
||||||
|
tok = self._token
|
||||||
|
if tok is None:
|
||||||
|
return None
|
||||||
|
if tok.is_valid():
|
||||||
|
return tok
|
||||||
|
if not tok.refresh_token:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Lazy import to avoid the oauth ↔ store import cycle.
|
||||||
|
from app.model_downloader.hf_auth.oauth import refresh_access_token
|
||||||
|
|
||||||
|
try:
|
||||||
|
refreshed = await refresh_access_token(tok.refresh_token)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("[hf_auth] token refresh failed: %s", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
# If a logout (clear) or another update replaced the token while we
|
||||||
|
# were awaiting the refresh, don't resurrect the old session.
|
||||||
|
if self._token is not tok:
|
||||||
|
return None
|
||||||
|
self._store_token_locked(refreshed)
|
||||||
|
return refreshed
|
||||||
|
|
||||||
|
|
||||||
|
HF_AUTH_STORE = HfAuthStore()
|
||||||
55
app/model_downloader/hf_auth/eligibility.py
Normal file
55
app/model_downloader/hf_auth/eligibility.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
"""Whether this deployment is allowed to do interactive HF OAuth.
|
||||||
|
|
||||||
|
We only let the server hold a HuggingFace token under a strict trust
|
||||||
|
assumption: this is a *single tenant local* install. Concretely:
|
||||||
|
|
||||||
|
- The server is bound to a loopback address. SSH tunneling /
|
||||||
|
reverse-proxies can defeat this, but it's the strongest signal
|
||||||
|
we have without an authentication system.
|
||||||
|
- ``--multi-user`` is off. A shared token used implicitly by multiple
|
||||||
|
declared users would be a footgun — one user's gated downloads
|
||||||
|
would silently authenticate as another.
|
||||||
|
|
||||||
|
Anything else and the frontend hides the HF login UI entirely; gated
|
||||||
|
models continue to show the "acquire it manually" message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
|
import socket
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
|
||||||
|
def _is_loopback(host: str | None) -> bool:
|
||||||
|
"""Duplicates ``server.is_loopback`` (small, no shared module yet).
|
||||||
|
|
||||||
|
Resolves a host or IP literal to whether it lives on the loopback
|
||||||
|
interface (127.0.0.0/8 for IPv4, ::1 for IPv6). Returns False for
|
||||||
|
``0.0.0.0`` / ``::`` because those are bind-all wildcards, not
|
||||||
|
loopback.
|
||||||
|
"""
|
||||||
|
if host is None:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
return ipaddress.ip_address(host).is_loopback
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
loopback = False
|
||||||
|
for family in (socket.AF_INET, socket.AF_INET6):
|
||||||
|
try:
|
||||||
|
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
|
||||||
|
for _family, _, _, _, sockaddr in r:
|
||||||
|
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
|
||||||
|
return loopback
|
||||||
|
loopback = True
|
||||||
|
except socket.gaierror:
|
||||||
|
pass
|
||||||
|
return loopback
|
||||||
|
|
||||||
|
|
||||||
|
def is_hf_auth_eligible() -> bool:
|
||||||
|
"""True iff this deployment may surface the HF OAuth flow."""
|
||||||
|
return _is_loopback(args.listen) and not args.multi_user
|
||||||
301
app/model_downloader/hf_auth/oauth.py
Normal file
301
app/model_downloader/hf_auth/oauth.py
Normal file
@ -0,0 +1,301 @@
|
|||||||
|
"""OAuth 2.0 PKCE flow against HuggingFace's authorization server.
|
||||||
|
|
||||||
|
Wired so that ``POST /api/hf-auth-login-start`` can:
|
||||||
|
1. Generate state + PKCE verifier/challenge in this process.
|
||||||
|
2. Spin up a short-lived loopback HTTP server at port 41954 to
|
||||||
|
receive the redirect callback from HF.
|
||||||
|
3. Return the ``authorize_url`` for the frontend to open in a new tab.
|
||||||
|
|
||||||
|
After the user grants consent on huggingface.co, HF redirects to the
|
||||||
|
local callback URL with ``code`` and ``state``. The callback server
|
||||||
|
validates ``state`` (CSRF), exchanges the code for tokens via PKCE,
|
||||||
|
hands the resulting Token to ``HF_AUTH_STORE.set_token``, and shuts
|
||||||
|
itself down.
|
||||||
|
|
||||||
|
Before this can be exercised end-to-end a maintainer must register a
|
||||||
|
HuggingFace OAuth app and substitute the ``HF_CLIENT_ID`` placeholder
|
||||||
|
below. See the comment above the constant for the exact steps.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
|
||||||
|
from app.model_downloader.hf_auth.token_store import Token
|
||||||
|
from app.model_downloader.http_client import get_session
|
||||||
|
|
||||||
|
|
||||||
|
# --- HF OAuth app registration -------------------------------------------- #
|
||||||
|
# NOTE: The OAuth client_id below is a placeholder. Before this feature can be
|
||||||
|
# exercised end-to-end, a maintainer must register a HuggingFace OAuth app
|
||||||
|
# under a Comfy-Org-controlled HF account and substitute its client_id here.
|
||||||
|
# Detailed walkthrough is in docs/server-side-model-downloads-handover.html
|
||||||
|
# ("HuggingFace OAuth app setup" section). Short version:
|
||||||
|
# 1. huggingface.co → Settings → Connected Apps → "Create app"
|
||||||
|
# 2. Default Scopes: check ``openid`` + ``profile`` (User Info) and
|
||||||
|
# ``gated-repos`` (Repository Access). Leave everything else off.
|
||||||
|
# 3. Redirect URLs: exactly ``http://127.0.0.1:41954/api/auth/huggingface/callback``
|
||||||
|
# — must match ``REDIRECT_URI`` below; change both in lockstep if you
|
||||||
|
# change ``CALLBACK_PORT``.
|
||||||
|
# 4. Save → copy the resulting Client ID into ``HF_CLIENT_ID`` below.
|
||||||
|
# The client_id is not a secret (it travels through the user's browser in
|
||||||
|
# plaintext); HF's "Public app" type means there's no client secret to
|
||||||
|
# manage — PKCE replaces it.
|
||||||
|
HF_CLIENT_ID = "REPLACE_ME_WITH_COMFY_ORG_HF_OAUTH_CLIENT_ID"
|
||||||
|
|
||||||
|
CALLBACK_HOST = "127.0.0.1"
|
||||||
|
CALLBACK_PORT = 41954
|
||||||
|
CALLBACK_PATH = "/api/auth/huggingface/callback"
|
||||||
|
REDIRECT_URI = f"http://{CALLBACK_HOST}:{CALLBACK_PORT}{CALLBACK_PATH}"
|
||||||
|
|
||||||
|
AUTHORIZE_URL = "https://huggingface.co/oauth/authorize"
|
||||||
|
TOKEN_URL = "https://huggingface.co/oauth/token"
|
||||||
|
_TOKEN_REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=30)
|
||||||
|
# Minimal scope set for the feature:
|
||||||
|
# - openid : required by HF when the app uses OIDC at all
|
||||||
|
# - profile : lets ``HfApi.whoami(token=...)`` return a username for the
|
||||||
|
# settings UI; cosmetic but expected
|
||||||
|
# - gated-repos : grants the token enough to call ``auth_check`` and
|
||||||
|
# download files from public gated repos the user has
|
||||||
|
# accepted the license for. The wider ``read-repos`` scope
|
||||||
|
# would also work (it includes ``gated-repos``) but it
|
||||||
|
# additionally grants private-repo read access, which we
|
||||||
|
# don't need and which makes the consent screen scarier
|
||||||
|
# for the user.
|
||||||
|
SCOPE = "openid profile gated-repos"
|
||||||
|
|
||||||
|
# Maximum time the callback server stays up waiting for the user to
|
||||||
|
# complete consent on huggingface.co. Past this, the port closes and
|
||||||
|
# the user has to click "Log in" again.
|
||||||
|
CALLBACK_TIMEOUT_SECS = 300
|
||||||
|
|
||||||
|
|
||||||
|
# Process-wide lock so two simultaneous /api/hf-auth-login-start
|
||||||
|
# requests don't fight over port CALLBACK_PORT.
|
||||||
|
_OAUTH_LOCK = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthInProgressError(Exception):
|
||||||
|
"""Another OAuth attempt is already running."""
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthCallbackError(Exception):
|
||||||
|
"""The OAuth callback returned an error (HF denied, port stolen, etc.)."""
|
||||||
|
|
||||||
|
|
||||||
|
# --- PKCE primitives ------------------------------------------------------ #
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pkce() -> tuple[str, str, str]:
|
||||||
|
"""Return ``(verifier, challenge, state)``.
|
||||||
|
|
||||||
|
Verifier never leaves this process. Challenge and state travel
|
||||||
|
through the user's browser. State is checked on the callback to
|
||||||
|
prevent a malicious cross-origin redirect from injecting a token.
|
||||||
|
"""
|
||||||
|
verifier = secrets.token_urlsafe(64)
|
||||||
|
challenge = (
|
||||||
|
base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("ascii")).digest())
|
||||||
|
.rstrip(b"=")
|
||||||
|
.decode("ascii")
|
||||||
|
)
|
||||||
|
state = secrets.token_urlsafe(32)
|
||||||
|
return verifier, challenge, state
|
||||||
|
|
||||||
|
|
||||||
|
def _build_authorize_url(challenge: str, state: str) -> str:
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"client_id": HF_CLIENT_ID,
|
||||||
|
"redirect_uri": REDIRECT_URI,
|
||||||
|
"response_type": "code",
|
||||||
|
"scope": SCOPE,
|
||||||
|
"state": state,
|
||||||
|
"code_challenge": challenge,
|
||||||
|
"code_challenge_method": "S256",
|
||||||
|
}
|
||||||
|
return f"{AUTHORIZE_URL}?{urlencode(params)}"
|
||||||
|
|
||||||
|
|
||||||
|
# --- Token exchange ------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
async def _exchange_code(code: str, verifier: str) -> Token:
|
||||||
|
"""Trade the authorization code for an access+refresh token pair."""
|
||||||
|
data = {
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": REDIRECT_URI,
|
||||||
|
"client_id": HF_CLIENT_ID,
|
||||||
|
"code_verifier": verifier,
|
||||||
|
}
|
||||||
|
session = await get_session()
|
||||||
|
async with session.post(TOKEN_URL, data=data, timeout=_TOKEN_REQUEST_TIMEOUT) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
body = await resp.json()
|
||||||
|
return Token(
|
||||||
|
access_token=body["access_token"],
|
||||||
|
refresh_token=body.get("refresh_token"),
|
||||||
|
expires_at=time.time() + float(body.get("expires_in", 3600)),
|
||||||
|
scope=body.get("scope", SCOPE),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def refresh_access_token(refresh_token: str) -> Token:
|
||||||
|
"""Trade a refresh_token for a new access (+ possibly refresh) token."""
|
||||||
|
data = {
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
"client_id": HF_CLIENT_ID,
|
||||||
|
}
|
||||||
|
session = await get_session()
|
||||||
|
async with session.post(TOKEN_URL, data=data, timeout=_TOKEN_REQUEST_TIMEOUT) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
body = await resp.json()
|
||||||
|
return Token(
|
||||||
|
access_token=body["access_token"],
|
||||||
|
# If HF doesn't rotate refresh tokens, keep using the existing one.
|
||||||
|
refresh_token=body.get("refresh_token", refresh_token),
|
||||||
|
expires_at=time.time() + float(body.get("expires_in", 3600)),
|
||||||
|
scope=body.get("scope", SCOPE),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Callback server ------------------------------------------------------ #
|
||||||
|
|
||||||
|
|
||||||
|
async def start_login_flow() -> str:
|
||||||
|
"""Begin one OAuth attempt: spawn the callback server, return the URL.
|
||||||
|
|
||||||
|
Returns the URL the frontend should open in a new tab. Raises
|
||||||
|
``OAuthInProgressError`` if another attempt is already running.
|
||||||
|
The callback server runs in the background until the user
|
||||||
|
completes consent or until ``CALLBACK_TIMEOUT_SECS`` elapses;
|
||||||
|
either way the lock + port are released afterward.
|
||||||
|
"""
|
||||||
|
if not _OAUTH_LOCK.acquire(blocking=False):
|
||||||
|
raise OAuthInProgressError()
|
||||||
|
|
||||||
|
try:
|
||||||
|
verifier, challenge, state = _make_pkce()
|
||||||
|
authorize_url = _build_authorize_url(challenge, state)
|
||||||
|
ready: asyncio.Future[None] = asyncio.get_event_loop().create_future()
|
||||||
|
except BaseException:
|
||||||
|
# Failed before handing the lock to the callback-server task: release it
|
||||||
|
# here. (Once the task is spawned, it owns releasing the lock.)
|
||||||
|
_OAUTH_LOCK.release()
|
||||||
|
raise
|
||||||
|
|
||||||
|
asyncio.create_task(_run_callback_server(verifier, state, ready))
|
||||||
|
# Don't return the URL until the callback server is actually bound and
|
||||||
|
# listening — otherwise HF could redirect to a port nothing is serving and
|
||||||
|
# the login would silently dead-end. ``ready`` raises if the bind failed.
|
||||||
|
await ready
|
||||||
|
return authorize_url
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_callback_server(
|
||||||
|
verifier: str, expected_state: str, ready: "asyncio.Future[None]"
|
||||||
|
) -> None:
|
||||||
|
"""Listen for HF's redirect once, capture the token, then shut down.
|
||||||
|
|
||||||
|
Signals ``ready`` once the port is bound (or with an exception if the bind
|
||||||
|
fails), so ``start_login_flow`` only hands back a URL on a live server.
|
||||||
|
"""
|
||||||
|
received: asyncio.Future[Token] = asyncio.get_event_loop().create_future()
|
||||||
|
|
||||||
|
async def handler(request: web.Request) -> web.Response:
|
||||||
|
try:
|
||||||
|
if request.query.get("state") != expected_state:
|
||||||
|
return web.Response(status=400, text="state mismatch")
|
||||||
|
err = request.query.get("error")
|
||||||
|
if err:
|
||||||
|
received.set_exception(OAuthCallbackError(f"HF returned: {err}"))
|
||||||
|
return web.Response(status=400, text=f"OAuth error: {err}")
|
||||||
|
code = request.query.get("code")
|
||||||
|
if not code:
|
||||||
|
return web.Response(status=400, text="missing code")
|
||||||
|
tok = await _exchange_code(code, verifier)
|
||||||
|
if not received.done():
|
||||||
|
received.set_result(tok)
|
||||||
|
return web.Response(
|
||||||
|
content_type="text/html",
|
||||||
|
text=(
|
||||||
|
"<html><body style='font-family:sans-serif;padding:40px'>"
|
||||||
|
"<h2>HuggingFace login successful</h2>"
|
||||||
|
"<p>You can close this tab and return to ComfyUI.</p>"
|
||||||
|
"</body></html>"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
if not received.done():
|
||||||
|
received.set_exception(exc)
|
||||||
|
return web.Response(status=500, text=str(exc))
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get(CALLBACK_PATH, handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
try:
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, CALLBACK_HOST, CALLBACK_PORT, reuse_address=True)
|
||||||
|
await site.start()
|
||||||
|
except Exception as e:
|
||||||
|
# Couldn't bind the callback port (commonly already in use). Tell the
|
||||||
|
# waiting start_login_flow via ``ready`` so it surfaces a clear error
|
||||||
|
# instead of returning a dead URL, and release the lock for next time.
|
||||||
|
logging.warning("[hf_auth] could not start callback server: %s", e)
|
||||||
|
if not ready.done():
|
||||||
|
ready.set_exception(
|
||||||
|
OAuthCallbackError(f"could not bind callback port {CALLBACK_PORT}: {e}")
|
||||||
|
)
|
||||||
|
_OAUTH_LOCK.release()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Bound and listening — now it's safe for start_login_flow to return the URL.
|
||||||
|
if not ready.done():
|
||||||
|
ready.set_result(None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
token = await asyncio.wait_for(received, timeout=CALLBACK_TIMEOUT_SECS)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logging.info("[hf_auth] OAuth login timed out after %ds", CALLBACK_TIMEOUT_SECS)
|
||||||
|
return
|
||||||
|
except OAuthCallbackError as e:
|
||||||
|
logging.warning("[hf_auth] OAuth callback error: %s", e)
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("[hf_auth] unexpected OAuth failure: %s", e)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
HF_AUTH_STORE.set_token(token)
|
||||||
|
logging.info("[hf_auth] OAuth login complete")
|
||||||
|
finally:
|
||||||
|
await runner.cleanup()
|
||||||
|
if _OAUTH_LOCK.locked():
|
||||||
|
_OAUTH_LOCK.release()
|
||||||
|
|
||||||
|
|
||||||
|
def is_login_in_progress() -> bool:
|
||||||
|
"""True iff a callback server is currently bound + waiting."""
|
||||||
|
return _OAUTH_LOCK.locked()
|
||||||
|
|
||||||
|
|
||||||
|
# Re-export for callers that only want the URL builder (e.g. tests).
|
||||||
|
__all__ = [
|
||||||
|
"start_login_flow",
|
||||||
|
"refresh_access_token",
|
||||||
|
"is_login_in_progress",
|
||||||
|
"OAuthInProgressError",
|
||||||
|
"CALLBACK_TIMEOUT_SECS",
|
||||||
|
]
|
||||||
94
app/model_downloader/hf_auth/token_store.py
Normal file
94
app/model_downloader/hf_auth/token_store.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
"""On-disk persistence for the HuggingFace OAuth token.
|
||||||
|
|
||||||
|
The token shape mirrors what HF returns on the token exchange: an
|
||||||
|
``access_token``, an optional ``refresh_token``, the absolute epoch at
|
||||||
|
which the access token expires, and the granted scope. We persist
|
||||||
|
this so logging in once survives ComfyUI restarts under the internal
|
||||||
|
``__hf_auth`` system-user directory; the file is mode ``0600`` so only
|
||||||
|
the OS user can read it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import stat
|
||||||
|
import time
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
|
||||||
|
# Treat a token as expired this many seconds before its server-reported
|
||||||
|
# ``expires_at`` so we don't try to use a token mid-request only for it
|
||||||
|
# to flip stale between auth_check and the actual GET.
|
||||||
|
EXPIRY_BUFFER_SECS = 60
|
||||||
|
|
||||||
|
TOKEN_FILENAME = "hf_auth_token.json"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Token:
|
||||||
|
"""One OAuth token + the metadata we need to use it."""
|
||||||
|
access_token: str
|
||||||
|
refresh_token: Optional[str]
|
||||||
|
expires_at: float # absolute epoch seconds
|
||||||
|
scope: str = ""
|
||||||
|
|
||||||
|
def is_valid(self) -> bool:
|
||||||
|
"""True iff we can use this token right now."""
|
||||||
|
return (
|
||||||
|
bool(self.access_token)
|
||||||
|
and (self.expires_at - time.time() > EXPIRY_BUFFER_SECS)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _token_dir() -> str:
|
||||||
|
return folder_paths.get_system_user_directory("hf_auth")
|
||||||
|
|
||||||
|
|
||||||
|
def _token_path() -> str:
|
||||||
|
return os.path.join(_token_dir(), TOKEN_FILENAME)
|
||||||
|
|
||||||
|
|
||||||
|
def load_token() -> Optional[Token]:
|
||||||
|
"""Read the persisted token, returning ``None`` if absent or corrupt."""
|
||||||
|
path = _token_path()
|
||||||
|
if not os.path.exists(path):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
return Token(**data)
|
||||||
|
except (OSError, json.JSONDecodeError, TypeError) as e:
|
||||||
|
logging.warning("[hf_auth] could not load token at %s: %s", path, e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def save_token(token: Token) -> None:
|
||||||
|
"""Atomically write the token with 0600 permissions."""
|
||||||
|
path = _token_path()
|
||||||
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
|
tmp = path + ".tmp"
|
||||||
|
fd = os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
||||||
|
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(asdict(token), f)
|
||||||
|
os.replace(tmp, path)
|
||||||
|
try:
|
||||||
|
os.chmod(path, stat.S_IRUSR | stat.S_IWUSR)
|
||||||
|
except OSError as e:
|
||||||
|
# On Windows / weird filesystems chmod may be a no-op; not fatal.
|
||||||
|
logging.debug("[hf_auth] chmod 0600 on %s failed: %s", path, e)
|
||||||
|
|
||||||
|
|
||||||
|
def delete_token() -> None:
|
||||||
|
"""Remove the persisted token; no-op if it doesn't exist."""
|
||||||
|
path = _token_path()
|
||||||
|
try:
|
||||||
|
os.remove(path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
except OSError as e:
|
||||||
|
logging.warning("[hf_auth] could not remove token at %s: %s", path, e)
|
||||||
41
app/model_downloader/hf_url.py
Normal file
41
app/model_downloader/hf_url.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
"""Parsers for the ``huggingface.co`` URL shape we accept in workflows.
|
||||||
|
|
||||||
|
The download API accepts URLs of the form
|
||||||
|
``https://huggingface.co/<org>/<repo>/resolve/<rev>/<path/to/file>``.
|
||||||
|
We need to recover ``<org>/<repo>`` (the *repo_id*) from such URLs for
|
||||||
|
``huggingface_hub`` API calls (notably ``HfApi.auth_check``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
_HF_HOST = "huggingface.co"
|
||||||
|
|
||||||
|
|
||||||
|
def is_hf_url(url: str) -> bool:
|
||||||
|
"""Cheap host check — does this URL point at huggingface.co?"""
|
||||||
|
try:
|
||||||
|
return urlparse(url).hostname == _HF_HOST
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def repo_id_from_url(url: str) -> Optional[str]:
|
||||||
|
"""Extract ``<org>/<repo>`` from an HF model file URL.
|
||||||
|
|
||||||
|
Returns ``None`` if the URL isn't on huggingface.co or doesn't look
|
||||||
|
like a model-file URL. The expected shape is
|
||||||
|
``/<org>/<repo>/resolve/<rev>/<path>`` — anything else
|
||||||
|
(datasets, spaces, /tree/, /blob/, …) we treat as out of scope here.
|
||||||
|
"""
|
||||||
|
if not is_hf_url(url):
|
||||||
|
return None
|
||||||
|
parts = urlparse(url).path.lstrip("/").split("/")
|
||||||
|
if len(parts) < 4 or parts[2] != "resolve":
|
||||||
|
return None
|
||||||
|
org, repo = parts[0], parts[1]
|
||||||
|
if not org or not repo:
|
||||||
|
return None
|
||||||
|
return f"{org}/{repo}"
|
||||||
63
app/model_downloader/http_client.py
Normal file
63
app/model_downloader/http_client.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
"""Lazy module-level aiohttp ClientSession.
|
||||||
|
|
||||||
|
A single shared session means TLS handshakes are reused across HEAD probes
|
||||||
|
and the subsequent GETs to the same host (HuggingFace is the dominant
|
||||||
|
case), which is a noticeable speedup on cold connections.
|
||||||
|
|
||||||
|
We deliberately don't close the session at process exit — aiohttp's
|
||||||
|
warning about unclosed sessions is benign at shutdown, and adding atexit
|
||||||
|
plumbing buys nothing because the OS reclaims the sockets anyway. The
|
||||||
|
session lifetime is the lifetime of the Python process.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import ssl
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import certifi
|
||||||
|
|
||||||
|
|
||||||
|
# Larger per-host pool than aiohttp's default (=100 total / =0 per host)
|
||||||
|
# so concurrent gated probes + a download to the same host don't queue.
|
||||||
|
_CONNECTOR_LIMIT_PER_HOST = 8
|
||||||
|
|
||||||
|
_session: Optional[aiohttp.ClientSession] = None
|
||||||
|
_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def ssl_context() -> ssl.SSLContext:
|
||||||
|
"""TLS context pinned to certifi's CA bundle.
|
||||||
|
aiohttp's default context uses the OS trust store, which isn't wired up
|
||||||
|
on some Python installs (python.org macOS, slim containers) — there TLS
|
||||||
|
to huggingface.co fails with CERTIFICATE_VERIFY_FAILED.
|
||||||
|
"""
|
||||||
|
return ssl.create_default_context(cafile=certifi.where())
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session() -> aiohttp.ClientSession:
|
||||||
|
"""Return the shared session, creating it on first call."""
|
||||||
|
global _session
|
||||||
|
if _session is not None and not _session.closed:
|
||||||
|
return _session
|
||||||
|
async with _lock:
|
||||||
|
if _session is None or _session.closed:
|
||||||
|
connector = aiohttp.TCPConnector(
|
||||||
|
limit_per_host=_CONNECTOR_LIMIT_PER_HOST,
|
||||||
|
ssl=ssl_context(),
|
||||||
|
)
|
||||||
|
_session = aiohttp.ClientSession(connector=connector)
|
||||||
|
return _session
|
||||||
|
|
||||||
|
|
||||||
|
def parse_content_length(value: Optional[str]) -> Optional[int]:
|
||||||
|
"""Parse a byte-count header value, or None if absent/malformed/negative."""
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
n = int(value)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
return n if n >= 0 else None
|
||||||
111
app/model_downloader/paths.py
Normal file
111
app/model_downloader/paths.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
"""Path resolution for model downloads.
|
||||||
|
|
||||||
|
Model identifiers used across the download API are *relative destination
|
||||||
|
paths* of the form ``<directory>/<filename>`` (e.g. ``loras/my_lora.safetensors``).
|
||||||
|
This module turns one of those identifiers into an absolute on-disk path
|
||||||
|
under one of ComfyUI's registered model folders, while rejecting unknown
|
||||||
|
folders, path traversal, and other ill-formed inputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
|
||||||
|
# Constrain components so a model_id can never escape its target directory.
|
||||||
|
# - directory: a single path segment of safe chars
|
||||||
|
# - filename: a single path segment of safe chars, must end with a model extension
|
||||||
|
_SEGMENT_RE = re.compile(r"^[A-Za-z0-9._-]+$")
|
||||||
|
|
||||||
|
# Destination filename must name a model file (same set as the URL allowlist),
|
||||||
|
# so a download can't land as e.g. ``foo.txt`` that ComfyUI won't recognise.
|
||||||
|
_MODEL_EXTENSIONS = (".safetensors", ".sft", ".ckpt", ".pth", ".pt")
|
||||||
|
|
||||||
|
# Distinctive temp suffix so the startup orphan-sweep only removes files THIS
|
||||||
|
# subsystem created — never unrelated ``*.tmp`` files in the model dirs.
|
||||||
|
_TMP_SUFFIX = ".comfy-download.tmp"
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidModelId(ValueError):
|
||||||
|
"""Raised when a model_id is syntactically invalid or refers to an
|
||||||
|
unknown model folder."""
|
||||||
|
|
||||||
|
|
||||||
|
def parse_model_id(model_id: str) -> Tuple[str, str]:
|
||||||
|
"""Split ``<directory>/<filename>`` and validate both components.
|
||||||
|
|
||||||
|
Returns ``(directory, filename)``. Raises ``InvalidModelId`` on
|
||||||
|
malformed input. Does NOT touch the filesystem.
|
||||||
|
"""
|
||||||
|
if not isinstance(model_id, str) or "/" not in model_id:
|
||||||
|
raise InvalidModelId(f"model_id must be '<directory>/<filename>', got {model_id!r}")
|
||||||
|
directory, _, filename = model_id.partition("/")
|
||||||
|
if "/" in filename or not directory or not filename:
|
||||||
|
raise InvalidModelId(f"model_id must be exactly one '/' separator, got {model_id!r}")
|
||||||
|
if not _SEGMENT_RE.match(directory):
|
||||||
|
raise InvalidModelId(f"invalid directory segment {directory!r}")
|
||||||
|
if not _SEGMENT_RE.match(filename):
|
||||||
|
raise InvalidModelId(f"invalid filename segment {filename!r}")
|
||||||
|
if not filename.endswith(_MODEL_EXTENSIONS):
|
||||||
|
raise InvalidModelId(
|
||||||
|
f"filename must end with a model extension {_MODEL_EXTENSIONS}, got {filename!r}"
|
||||||
|
)
|
||||||
|
if directory not in folder_paths.folder_names_and_paths:
|
||||||
|
raise InvalidModelId(f"unknown model folder {directory!r}")
|
||||||
|
return directory, filename
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_existing(model_id: str) -> Optional[str]:
|
||||||
|
"""Return the absolute path of an installed model, or None if missing.
|
||||||
|
|
||||||
|
Honours ``extra_model_paths.yaml`` transparently via
|
||||||
|
``folder_paths.get_full_path``.
|
||||||
|
"""
|
||||||
|
directory, filename = parse_model_id(model_id)
|
||||||
|
return folder_paths.get_full_path(directory, filename)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_destination(model_id: str, epoch: int = 0) -> Tuple[str, str]:
|
||||||
|
"""Return ``(final_path, tmp_path)`` for a download.
|
||||||
|
|
||||||
|
Downloads land at the first registered path for the model's directory
|
||||||
|
(the "primary" location). The temp sibling is the write target, atomically
|
||||||
|
renamed onto ``final_path`` on success.
|
||||||
|
|
||||||
|
``tmp_path`` embeds the session ``epoch`` so a cancel+retry of the same
|
||||||
|
model never shares a temp path between the old (cancelling) worker and the
|
||||||
|
new attempt — otherwise the old worker's rollback could delete the new
|
||||||
|
worker's in-progress file. The distinctive suffix scopes the orphan sweep.
|
||||||
|
"""
|
||||||
|
directory, filename = parse_model_id(model_id)
|
||||||
|
roots = folder_paths.get_folder_paths(directory)
|
||||||
|
if not roots:
|
||||||
|
raise InvalidModelId(f"no on-disk path registered for folder {directory!r}")
|
||||||
|
root = roots[0]
|
||||||
|
final_path = os.path.join(root, filename)
|
||||||
|
tmp_path = f"{final_path}.{epoch}{_TMP_SUFFIX}"
|
||||||
|
return final_path, tmp_path
|
||||||
|
|
||||||
|
|
||||||
|
def iter_all_tmp_paths():
|
||||||
|
"""Yield this subsystem's temp files under every registered model folder.
|
||||||
|
|
||||||
|
Matches only our distinctive ``_TMP_SUFFIX`` (not every ``*.tmp``) so the
|
||||||
|
startup orphan-sweep can't delete temp files created by other tools.
|
||||||
|
"""
|
||||||
|
seen_roots: set[str] = set()
|
||||||
|
for directory in folder_paths.folder_names_and_paths.keys():
|
||||||
|
for root in folder_paths.get_folder_paths(directory):
|
||||||
|
if root in seen_roots or not os.path.isdir(root):
|
||||||
|
continue
|
||||||
|
seen_roots.add(root)
|
||||||
|
try:
|
||||||
|
for entry in os.scandir(root):
|
||||||
|
if entry.is_file() and entry.name.endswith(_TMP_SUFFIX):
|
||||||
|
yield entry.path
|
||||||
|
except OSError:
|
||||||
|
# Folder might be unreadable / missing on certain mounts —
|
||||||
|
# not fatal, just skip it.
|
||||||
|
continue
|
||||||
@ -98,12 +98,24 @@ def _parse_cli_feature_flags() -> dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
# Default server capabilities
|
# Default server capabilities
|
||||||
|
def _hf_auth_eligible_at_startup() -> bool:
|
||||||
|
"""Snapshot eligibility once at feature-flag init time.
|
||||||
|
|
||||||
|
Imports lazily because the flags module loads very early in the
|
||||||
|
server boot sequence — earlier than the model_downloader package.
|
||||||
|
"""
|
||||||
|
from app.model_downloader.hf_auth.eligibility import is_hf_auth_eligible
|
||||||
|
return is_hf_auth_eligible()
|
||||||
|
|
||||||
|
|
||||||
_CORE_FEATURE_FLAGS: dict[str, Any] = {
|
_CORE_FEATURE_FLAGS: dict[str, Any] = {
|
||||||
"supports_preview_metadata": True,
|
"supports_preview_metadata": True,
|
||||||
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
||||||
"extension": {"manager": {"supports_v4": True}},
|
"extension": {"manager": {"supports_v4": True}},
|
||||||
"node_replacements": True,
|
"node_replacements": True,
|
||||||
"assets": args.enable_assets,
|
"assets": args.enable_assets,
|
||||||
|
"server_side_model_downloads": True,
|
||||||
|
"hf_auth_eligible": _hf_auth_eligible_at_startup(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# CLI-provided flags cannot overwrite core flags
|
# CLI-provided flags cannot overwrite core flags
|
||||||
|
|||||||
1273
docs/server-side-model-downloads-handover.html
Normal file
1273
docs/server-side-model-downloads-handover.html
Normal file
File diff suppressed because it is too large
Load Diff
341
openapi.yaml
341
openapi.yaml
@ -188,6 +188,49 @@ components:
|
|||||||
- id
|
- id
|
||||||
- updated_at
|
- updated_at
|
||||||
type: object
|
type: object
|
||||||
|
AvailabilityStatusRequest:
|
||||||
|
description: |
|
||||||
|
Models to query — each entry is `model_id → URL`. The URL lets
|
||||||
|
the server compute file_size + is_hf_downloadable on the same
|
||||||
|
request, eliminating the need for a separate metadata endpoint.
|
||||||
|
properties:
|
||||||
|
models:
|
||||||
|
additionalProperties:
|
||||||
|
type: string
|
||||||
|
description: model_id → URL declared in the workflow.
|
||||||
|
type: object
|
||||||
|
required:
|
||||||
|
- models
|
||||||
|
type: object
|
||||||
|
AvailabilityStatusResponse:
|
||||||
|
description: Per-model state + metadata + HF auth snapshot.
|
||||||
|
properties:
|
||||||
|
hf_auth:
|
||||||
|
$ref: '#/components/schemas/HfAuthStatus'
|
||||||
|
models:
|
||||||
|
additionalProperties:
|
||||||
|
$ref: '#/components/schemas/ModelStatusEntry'
|
||||||
|
type: object
|
||||||
|
required:
|
||||||
|
- models
|
||||||
|
- hf_auth
|
||||||
|
type: object
|
||||||
|
CancelDownloadSessionRequest:
|
||||||
|
description: Request to cancel an in-flight download for a given model_id.
|
||||||
|
properties:
|
||||||
|
model_id:
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- model_id
|
||||||
|
type: object
|
||||||
|
CancelDownloadSessionResponse:
|
||||||
|
description: Result of a cancellation request.
|
||||||
|
properties:
|
||||||
|
cancelled:
|
||||||
|
type: boolean
|
||||||
|
required:
|
||||||
|
- cancelled
|
||||||
|
type: object
|
||||||
CreateWorkflowRequest:
|
CreateWorkflowRequest:
|
||||||
description: Request body for creating a new saved workflow.
|
description: Request body for creating a new saved workflow.
|
||||||
properties:
|
properties:
|
||||||
@ -230,6 +273,51 @@ components:
|
|||||||
- base_version
|
- base_version
|
||||||
- workflow_json
|
- workflow_json
|
||||||
type: object
|
type: object
|
||||||
|
DownloadModelsRequest:
|
||||||
|
description: Map of model_id → URL of files to fetch into the model folders.
|
||||||
|
properties:
|
||||||
|
models:
|
||||||
|
additionalProperties:
|
||||||
|
type: string
|
||||||
|
description: model_id → URL of models to download.
|
||||||
|
type: object
|
||||||
|
required:
|
||||||
|
- models
|
||||||
|
type: object
|
||||||
|
DownloadModelsResponse:
|
||||||
|
description: Acknowledgement that downloads have been scheduled.
|
||||||
|
properties:
|
||||||
|
accepted:
|
||||||
|
description: Always true; the request was scheduled.
|
||||||
|
type: boolean
|
||||||
|
scheduled:
|
||||||
|
description: The list of model_ids whose downloads are now in-flight.
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
type: array
|
||||||
|
required:
|
||||||
|
- accepted
|
||||||
|
- scheduled
|
||||||
|
type: object
|
||||||
|
DownloadProgress:
|
||||||
|
description: In-flight download progress; embedded in ModelStatusEntry.
|
||||||
|
properties:
|
||||||
|
bytes_downloaded:
|
||||||
|
format: int64
|
||||||
|
type: integer
|
||||||
|
progress:
|
||||||
|
description: Fraction in [0,1]; null until total_bytes is known.
|
||||||
|
format: float
|
||||||
|
nullable: true
|
||||||
|
type: number
|
||||||
|
total_bytes:
|
||||||
|
description: Content-Length when supplied by the source.
|
||||||
|
format: int64
|
||||||
|
nullable: true
|
||||||
|
type: integer
|
||||||
|
required:
|
||||||
|
- bytes_downloaded
|
||||||
|
type: object
|
||||||
ErrorResponse:
|
ErrorResponse:
|
||||||
description: Standard error response with a machine-readable code and human-readable message.
|
description: Standard error response with a machine-readable code and human-readable message.
|
||||||
properties:
|
properties:
|
||||||
@ -394,6 +482,46 @@ components:
|
|||||||
- name
|
- name
|
||||||
- info
|
- info
|
||||||
type: object
|
type: object
|
||||||
|
HfAuthLoginStartResponse:
|
||||||
|
description: URL the frontend should open in a new tab to complete login.
|
||||||
|
properties:
|
||||||
|
authorize_url:
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- authorize_url
|
||||||
|
type: object
|
||||||
|
HfAuthLogoutResponse:
|
||||||
|
description: Result of the logout call (always logged_out = true).
|
||||||
|
properties:
|
||||||
|
logged_out:
|
||||||
|
type: boolean
|
||||||
|
required:
|
||||||
|
- logged_out
|
||||||
|
type: object
|
||||||
|
HfAuthStatus:
|
||||||
|
description: Inline snapshot of the server's HuggingFace OAuth state.
|
||||||
|
properties:
|
||||||
|
eligible:
|
||||||
|
description: True iff this deployment can surface interactive HF login.
|
||||||
|
type: boolean
|
||||||
|
token_available:
|
||||||
|
description: True iff a token (possibly expired but refreshable) is stored.
|
||||||
|
type: boolean
|
||||||
|
required:
|
||||||
|
- token_available
|
||||||
|
- eligible
|
||||||
|
type: object
|
||||||
|
HfAuthTokenStatusResponse:
|
||||||
|
description: Whether the server holds an HF OAuth token + resolved username.
|
||||||
|
properties:
|
||||||
|
token_available:
|
||||||
|
type: boolean
|
||||||
|
username:
|
||||||
|
nullable: true
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- token_available
|
||||||
|
type: object
|
||||||
HistoryDetailEntry:
|
HistoryDetailEntry:
|
||||||
description: History entry with full prompt data
|
description: History entry with full prompt data
|
||||||
properties:
|
properties:
|
||||||
@ -798,6 +926,41 @@ components:
|
|||||||
- name
|
- name
|
||||||
- folders
|
- folders
|
||||||
type: object
|
type: object
|
||||||
|
ModelStatusEntry:
|
||||||
|
description: Everything the UI needs to render one row of the model.
|
||||||
|
properties:
|
||||||
|
file_size:
|
||||||
|
description: Bytes, when known. Cached server-side per URL.
|
||||||
|
format: int64
|
||||||
|
nullable: true
|
||||||
|
type: integer
|
||||||
|
is_hf_downloadable:
|
||||||
|
description: |
|
||||||
|
HuggingFace-only signal. True if the server can fetch this
|
||||||
|
URL with its current auth state (public, or gated-with-access).
|
||||||
|
False if gated and lacking access. Null for non-HF URLs and
|
||||||
|
for HF URLs whose probe failed entirely.
|
||||||
|
nullable: true
|
||||||
|
type: boolean
|
||||||
|
progress:
|
||||||
|
type: object
|
||||||
|
allOf:
|
||||||
|
- $ref: '#/components/schemas/DownloadProgress'
|
||||||
|
description: Present when `state == downloading`.
|
||||||
|
nullable: true
|
||||||
|
state:
|
||||||
|
description: |
|
||||||
|
`available` — file is on disk.
|
||||||
|
`missing` — not on disk and no download in flight.
|
||||||
|
`downloading` — server is currently fetching the file.
|
||||||
|
enum:
|
||||||
|
- available
|
||||||
|
- missing
|
||||||
|
- downloading
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- state
|
||||||
|
type: object
|
||||||
NodeInfo:
|
NodeInfo:
|
||||||
description: Metadata describing a single ComfyUI node type and its inputs/outputs.
|
description: Metadata describing a single ComfyUI node type and its inputs/outputs.
|
||||||
properties:
|
properties:
|
||||||
@ -2350,6 +2513,72 @@ paths:
|
|||||||
summary: Get tag histogram for filtered assets
|
summary: Get tag histogram for filtered assets
|
||||||
tags:
|
tags:
|
||||||
- file
|
- file
|
||||||
|
/api/cancel-model-download-session:
|
||||||
|
post:
|
||||||
|
description: |
|
||||||
|
Cancels the download session for the given model_id. The worker
|
||||||
|
observes the cancellation between chunks, removes its partial `.tmp`
|
||||||
|
file, and exits without writing the destination path.
|
||||||
|
operationId: postCancelModelDownloadSession
|
||||||
|
requestBody:
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/CancelDownloadSessionRequest'
|
||||||
|
required: true
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/CancelDownloadSessionResponse'
|
||||||
|
description: Session cancelled.
|
||||||
|
"404":
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/ErrorResponse'
|
||||||
|
description: No active download for that model_id.
|
||||||
|
summary: Cancel an in-flight server-side model download
|
||||||
|
tags:
|
||||||
|
- model
|
||||||
|
/api/download-models:
|
||||||
|
post:
|
||||||
|
description: |
|
||||||
|
Schedules downloads for every model_id in the request map. Returns
|
||||||
|
immediately after validation; progress is observed via
|
||||||
|
`/api/models-availability-status`. Fails atomically if any model
|
||||||
|
is already on disk, already downloading, gated, or has a URL that
|
||||||
|
is not on the server's allowlist (HuggingFace, Civitai, localhost).
|
||||||
|
operationId: postDownloadModels
|
||||||
|
requestBody:
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/DownloadModelsRequest'
|
||||||
|
required: true
|
||||||
|
responses:
|
||||||
|
"202":
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/DownloadModelsResponse'
|
||||||
|
description: Downloads accepted and scheduled.
|
||||||
|
"400":
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/ErrorResponse'
|
||||||
|
description: One of the requested models is invalid, gated, or has a non-allowed URL.
|
||||||
|
"409":
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/ErrorResponse'
|
||||||
|
description: One of the requested models is already on disk or downloading.
|
||||||
|
summary: Start a server-side download of one or more models
|
||||||
|
tags:
|
||||||
|
- model
|
||||||
/api/embeddings:
|
/api/embeddings:
|
||||||
get:
|
get:
|
||||||
description: Returns the list of text-encoder embeddings available on disk.
|
description: Returns the list of text-encoder embeddings available on disk.
|
||||||
@ -2655,6 +2884,74 @@ paths:
|
|||||||
summary: Get a specific subgraph blueprint
|
summary: Get a specific subgraph blueprint
|
||||||
tags:
|
tags:
|
||||||
- workflow
|
- workflow
|
||||||
|
/api/hf-auth-login-start:
|
||||||
|
post:
|
||||||
|
description: |
|
||||||
|
Spawns a short-lived loopback callback server (port 41954) and
|
||||||
|
returns the URL the frontend should open in a new tab. After the
|
||||||
|
user grants consent, HF redirects back to the callback URL with
|
||||||
|
an authorization code; the server exchanges that for a token and
|
||||||
|
persists it. Subsequent `/api/hf-auth-token-status` calls will
|
||||||
|
return `token_available: true`. Rejected with 403 if the
|
||||||
|
deployment is not eligible (not loopback or in --multi-user mode);
|
||||||
|
409 if another login attempt is already in progress.
|
||||||
|
operationId: postHfAuthLoginStart
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/HfAuthLoginStartResponse'
|
||||||
|
description: Login flow started; `authorize_url` is ready.
|
||||||
|
"403":
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/ErrorResponse'
|
||||||
|
description: Deployment is not eligible for interactive HF login.
|
||||||
|
"409":
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/ErrorResponse'
|
||||||
|
description: Another login attempt is already in progress.
|
||||||
|
summary: Begin a HuggingFace OAuth login flow
|
||||||
|
tags:
|
||||||
|
- model
|
||||||
|
/api/hf-auth-logout:
|
||||||
|
post:
|
||||||
|
description: Clears the in-memory cache and removes the on-disk token file.
|
||||||
|
operationId: postHfAuthLogout
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/HfAuthLogoutResponse'
|
||||||
|
description: Logged out (idempotent — succeeds even if no token was held).
|
||||||
|
summary: Drop the stored HuggingFace OAuth token
|
||||||
|
tags:
|
||||||
|
- model
|
||||||
|
/api/hf-auth-token-status:
|
||||||
|
get:
|
||||||
|
description: |
|
||||||
|
Returns `token_available: true` when the server has a token
|
||||||
|
in memory (or on disk) for HuggingFace, irrespective of whether
|
||||||
|
the access_token is currently fresh — an expired one with a
|
||||||
|
refresh_token still counts as "logged in" because we'll refresh
|
||||||
|
transparently on next use. If a username is resolvable via
|
||||||
|
`HfApi.whoami` we return that too, for the settings UI.
|
||||||
|
operationId: getHfAuthTokenStatus
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/HfAuthTokenStatusResponse'
|
||||||
|
description: Token status.
|
||||||
|
summary: Whether the server holds a usable HuggingFace OAuth token
|
||||||
|
tags:
|
||||||
|
- model
|
||||||
/api/history:
|
/api/history:
|
||||||
post:
|
post:
|
||||||
deprecated: true
|
deprecated: true
|
||||||
@ -3157,6 +3454,48 @@ paths:
|
|||||||
summary: Cancel multiple jobs
|
summary: Cancel multiple jobs
|
||||||
tags:
|
tags:
|
||||||
- workflow
|
- workflow
|
||||||
|
/api/models-availability-status:
|
||||||
|
post:
|
||||||
|
description: |
|
||||||
|
Given a map of `{model_id: url}` (model_id is
|
||||||
|
`<directory>/<filename>`), returns per-id state plus the
|
||||||
|
metadata the UI needs to render the row:
|
||||||
|
|
||||||
|
- `state` — one of `available` / `missing` / `downloading`
|
||||||
|
- `progress` — embedded when `state == downloading`
|
||||||
|
- `file_size` — bytes (when known)
|
||||||
|
- `is_hf_downloadable` — for HF URLs only: true if the
|
||||||
|
server can currently fetch the file with its stored auth
|
||||||
|
state, false if gated and lacking access, null otherwise
|
||||||
|
|
||||||
|
Designed for 1 Hz polling. `file_size` and the intrinsic
|
||||||
|
"is this model gated" check are cached server-side per URL;
|
||||||
|
`is_hf_downloadable` is recomputed per call so license
|
||||||
|
acceptance and login/logout transitions show up within one
|
||||||
|
poll interval without any client-side cache plumbing.
|
||||||
|
operationId: postModelsAvailabilityStatus
|
||||||
|
requestBody:
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/AvailabilityStatusRequest'
|
||||||
|
required: true
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/AvailabilityStatusResponse'
|
||||||
|
description: Per-model status and metadata.
|
||||||
|
"400":
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/ErrorResponse'
|
||||||
|
description: Malformed request body.
|
||||||
|
summary: Unified per-model status + metadata for the polling UI
|
||||||
|
tags:
|
||||||
|
- model
|
||||||
/api/node_replacements:
|
/api/node_replacements:
|
||||||
get:
|
get:
|
||||||
description: |
|
description: |
|
||||||
@ -5103,3 +5442,5 @@ tags:
|
|||||||
name: queue
|
name: queue
|
||||||
- description: Job lifecycle queries
|
- description: Job lifecycle queries
|
||||||
name: job
|
name: job
|
||||||
|
- description: Server-side model availability and downloads
|
||||||
|
name: model
|
||||||
|
|||||||
@ -9,6 +9,7 @@ numpy>=1.25.0
|
|||||||
einops
|
einops
|
||||||
transformers>=4.50.3
|
transformers>=4.50.3
|
||||||
tokenizers>=0.13.3
|
tokenizers>=0.13.3
|
||||||
|
huggingface_hub
|
||||||
sentencepiece
|
sentencepiece
|
||||||
safetensors>=0.4.2
|
safetensors>=0.4.2
|
||||||
aiohttp>=3.11.8
|
aiohttp>=3.11.8
|
||||||
|
|||||||
@ -47,6 +47,7 @@ from app.assets.seeder import asset_seeder
|
|||||||
from app.assets.api.routes import register_assets_routes
|
from app.assets.api.routes import register_assets_routes
|
||||||
from app.assets.services.ingest import register_file_in_place
|
from app.assets.services.ingest import register_file_in_place
|
||||||
from app.assets.services.asset_management import resolve_hash_to_path
|
from app.assets.services.asset_management import resolve_hash_to_path
|
||||||
|
from app.model_downloader.api.routes import register_routes as register_model_downloader_routes
|
||||||
|
|
||||||
from app.user_manager import UserManager
|
from app.user_manager import UserManager
|
||||||
from app.model_manager import ModelFileManager
|
from app.model_manager import ModelFileManager
|
||||||
@ -257,6 +258,7 @@ class PromptServer():
|
|||||||
else:
|
else:
|
||||||
register_assets_routes(self.app)
|
register_assets_routes(self.app)
|
||||||
asset_seeder.disable()
|
asset_seeder.disable()
|
||||||
|
register_model_downloader_routes(self.app)
|
||||||
routes = web.RouteTableDef()
|
routes = web.RouteTableDef()
|
||||||
self.routes = routes
|
self.routes = routes
|
||||||
self.last_node_id = None
|
self.last_node_id = None
|
||||||
|
|||||||
708
tests-unit/app_test/hf_auth_test.py
Normal file
708
tests-unit/app_test/hf_auth_test.py
Normal file
@ -0,0 +1,708 @@
|
|||||||
|
"""Unit tests for the HuggingFace auth subsystem.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- token store: save/load roundtrip, chmod 0600, atomic write, delete
|
||||||
|
- eligibility under various CLI-arg combinations
|
||||||
|
- URL parsing (huggingface.co host detection + repo_id extraction)
|
||||||
|
- HF-aware gated_detection.probe_url (mocked auth_check)
|
||||||
|
- HF auth routes (token status, login start with eligibility gate, logout)
|
||||||
|
- PKCE primitives + authorize URL shape
|
||||||
|
|
||||||
|
The OAuth callback server itself isn't exercised end-to-end here — that
|
||||||
|
requires a real HF server. We test the components (state checking,
|
||||||
|
URL building, code-exchange request shape) instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import stat
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from app.model_downloader.api.routes import register_routes
|
||||||
|
from app.model_downloader.hf_auth import oauth
|
||||||
|
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE, HfAuthStore
|
||||||
|
from app.model_downloader.hf_auth.token_store import (
|
||||||
|
EXPIRY_BUFFER_SECS,
|
||||||
|
Token,
|
||||||
|
delete_token,
|
||||||
|
load_token,
|
||||||
|
save_token,
|
||||||
|
)
|
||||||
|
from app.model_downloader.hf_url import is_hf_url, repo_id_from_url
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Fixtures
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def patched_user_dir(tmp_path):
|
||||||
|
"""Redirect ``folder_paths.get_user_directory`` so the token file
|
||||||
|
lands in an isolated tmp_path instead of the real user dir."""
|
||||||
|
user_dir = tmp_path / "user"
|
||||||
|
user_dir.mkdir()
|
||||||
|
with patch("folder_paths.get_user_directory", return_value=str(user_dir)):
|
||||||
|
yield user_dir
|
||||||
|
|
||||||
|
|
||||||
|
def _token_file_path(user_dir) -> str:
|
||||||
|
return os.path.join(user_dir, "__hf_auth", "hf_auth_token.json")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fresh_auth_store():
|
||||||
|
"""Wipe singleton state between tests: auth + probe caches."""
|
||||||
|
from app.model_downloader import gated_detection
|
||||||
|
|
||||||
|
HF_AUTH_STORE._token = None
|
||||||
|
HF_AUTH_STORE._loaded_from_disk = False
|
||||||
|
gated_detection.clear_caches_for_tests()
|
||||||
|
yield HF_AUTH_STORE
|
||||||
|
HF_AUTH_STORE._token = None
|
||||||
|
HF_AUTH_STORE._loaded_from_disk = False
|
||||||
|
gated_detection.clear_caches_for_tests()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(patched_user_dir, fresh_auth_store):
|
||||||
|
app = web.Application()
|
||||||
|
register_routes(app)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# URL parsing
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_hf_url_recognises_huggingface_co():
|
||||||
|
assert is_hf_url("https://huggingface.co/x/y/resolve/main/z.safetensors")
|
||||||
|
assert is_hf_url("https://huggingface.co/abc")
|
||||||
|
assert not is_hf_url("https://hf-mirror.com/x/y/resolve/main/z.safetensors")
|
||||||
|
assert not is_hf_url("https://civitai.com/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_repo_id_from_url_extracts_org_and_repo():
|
||||||
|
url = "https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-HDR/resolve/main/x.safetensors"
|
||||||
|
assert repo_id_from_url(url) == "Lightricks/LTX-2.3-22b-IC-LoRA-HDR"
|
||||||
|
|
||||||
|
|
||||||
|
def test_repo_id_from_url_handles_nested_path():
|
||||||
|
url = "https://huggingface.co/Comfy-Org/ltx-2.3/resolve/main/split_files/loras/x.safetensors"
|
||||||
|
assert repo_id_from_url(url) == "Comfy-Org/ltx-2.3"
|
||||||
|
|
||||||
|
|
||||||
|
def test_repo_id_from_url_returns_none_for_non_hf():
|
||||||
|
assert repo_id_from_url("https://civitai.com/x.safetensors") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_repo_id_from_url_returns_none_for_non_resolve_paths():
|
||||||
|
assert repo_id_from_url("https://huggingface.co/org/repo/blob/main/x.safetensors") is None
|
||||||
|
assert repo_id_from_url("https://huggingface.co/org") is None
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Token store
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_store_roundtrip(patched_user_dir):
|
||||||
|
tok = Token(
|
||||||
|
access_token="hf_abc",
|
||||||
|
refresh_token="rf_def",
|
||||||
|
expires_at=9999999999.0,
|
||||||
|
scope="openid profile",
|
||||||
|
)
|
||||||
|
save_token(tok)
|
||||||
|
loaded = load_token()
|
||||||
|
assert loaded == tok
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_store_writes_0600(patched_user_dir):
|
||||||
|
tok = Token(access_token="x", refresh_token=None, expires_at=0.0)
|
||||||
|
save_token(tok)
|
||||||
|
path = _token_file_path(patched_user_dir)
|
||||||
|
mode = stat.S_IMODE(os.stat(path).st_mode)
|
||||||
|
# On Windows we silently no-op chmod; allow either the intended
|
||||||
|
# mode or whatever umask the OS gave us.
|
||||||
|
if os.name == "posix":
|
||||||
|
assert mode == 0o600
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_store_delete_removes_file(patched_user_dir):
|
||||||
|
tok = Token(access_token="x", refresh_token=None, expires_at=0.0)
|
||||||
|
save_token(tok)
|
||||||
|
delete_token()
|
||||||
|
path = _token_file_path(patched_user_dir)
|
||||||
|
assert not os.path.exists(path)
|
||||||
|
# Idempotent: second delete is fine.
|
||||||
|
delete_token()
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_store_load_returns_none_for_missing_file(patched_user_dir):
|
||||||
|
assert load_token() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_store_load_returns_none_for_corrupt_file(patched_user_dir):
|
||||||
|
path = _token_file_path(patched_user_dir)
|
||||||
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
|
with open(path, "w") as f:
|
||||||
|
f.write("not json {")
|
||||||
|
assert load_token() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_is_valid_uses_buffer(patched_user_dir):
|
||||||
|
import time
|
||||||
|
|
||||||
|
fresh = Token(access_token="x", refresh_token=None, expires_at=time.time() + 3600)
|
||||||
|
nearly_expired = Token(
|
||||||
|
access_token="x",
|
||||||
|
refresh_token=None,
|
||||||
|
expires_at=time.time() + EXPIRY_BUFFER_SECS - 1,
|
||||||
|
)
|
||||||
|
assert fresh.is_valid()
|
||||||
|
assert not nearly_expired.is_valid()
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_is_valid_rejects_empty_access_token():
|
||||||
|
import time
|
||||||
|
|
||||||
|
tok = Token(access_token="", refresh_token=None, expires_at=time.time() + 3600)
|
||||||
|
assert not tok.is_valid()
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_is_valid_rejects_at_exact_buffer_boundary():
|
||||||
|
import time
|
||||||
|
|
||||||
|
tok = Token(
|
||||||
|
access_token="x",
|
||||||
|
refresh_token=None,
|
||||||
|
expires_at=time.time() + EXPIRY_BUFFER_SECS,
|
||||||
|
)
|
||||||
|
assert not tok.is_valid()
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Auth store
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_store_loads_lazily(patched_user_dir):
|
||||||
|
tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
|
||||||
|
save_token(tok)
|
||||||
|
store = HfAuthStore()
|
||||||
|
assert store.has_token()
|
||||||
|
assert store.get_token_sync() == tok
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_store_set_persists(patched_user_dir):
|
||||||
|
store = HfAuthStore()
|
||||||
|
tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
|
||||||
|
store.set_token(tok)
|
||||||
|
# Token is on disk now — a fresh store sees it.
|
||||||
|
assert HfAuthStore().get_token_sync() == tok
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_store_clear_removes_in_memory_and_on_disk(patched_user_dir):
|
||||||
|
store = HfAuthStore()
|
||||||
|
tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
|
||||||
|
store.set_token(tok)
|
||||||
|
store.clear()
|
||||||
|
assert not store.has_token()
|
||||||
|
assert HfAuthStore().get_token_sync() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_store_has_token_true_when_expired_but_refreshable(patched_user_dir):
|
||||||
|
import time
|
||||||
|
|
||||||
|
store = HfAuthStore()
|
||||||
|
expired = Token(
|
||||||
|
access_token="old",
|
||||||
|
refresh_token="rf",
|
||||||
|
expires_at=time.time() - 100,
|
||||||
|
)
|
||||||
|
store.set_token(expired)
|
||||||
|
assert store.has_token()
|
||||||
|
assert not expired.is_valid()
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_store_get_token_sync_returns_expired_without_refresh(patched_user_dir):
|
||||||
|
import time
|
||||||
|
|
||||||
|
store = HfAuthStore()
|
||||||
|
expired = Token(
|
||||||
|
access_token="old",
|
||||||
|
refresh_token=None,
|
||||||
|
expires_at=time.time() - 100,
|
||||||
|
)
|
||||||
|
store.set_token(expired)
|
||||||
|
assert store.get_token_sync() == expired
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_store_get_valid_returns_none_when_expired_without_refresh(
|
||||||
|
patched_user_dir,
|
||||||
|
):
|
||||||
|
import time
|
||||||
|
|
||||||
|
store = HfAuthStore()
|
||||||
|
expired = Token(
|
||||||
|
access_token="old",
|
||||||
|
refresh_token=None,
|
||||||
|
expires_at=time.time() - 100,
|
||||||
|
)
|
||||||
|
store.set_token(expired)
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.hf_auth.oauth.refresh_access_token",
|
||||||
|
new=AsyncMock(),
|
||||||
|
) as refresh_mock:
|
||||||
|
result = await store.get_valid_token()
|
||||||
|
assert result is None
|
||||||
|
refresh_mock.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_store_get_valid_returns_fresh_token(patched_user_dir):
|
||||||
|
store = HfAuthStore()
|
||||||
|
import time
|
||||||
|
|
||||||
|
tok = Token(access_token="x", refresh_token=None, expires_at=time.time() + 3600)
|
||||||
|
store.set_token(tok)
|
||||||
|
fetched = await store.get_valid_token()
|
||||||
|
assert fetched == tok
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_store_get_valid_refresh_on_expired(patched_user_dir):
|
||||||
|
store = HfAuthStore()
|
||||||
|
import time
|
||||||
|
|
||||||
|
expired = Token(
|
||||||
|
access_token="old",
|
||||||
|
refresh_token="rf",
|
||||||
|
expires_at=time.time() - 100,
|
||||||
|
)
|
||||||
|
store.set_token(expired)
|
||||||
|
refreshed = Token(
|
||||||
|
access_token="new",
|
||||||
|
refresh_token="rf",
|
||||||
|
expires_at=time.time() + 3600,
|
||||||
|
)
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.hf_auth.oauth.refresh_access_token",
|
||||||
|
new=AsyncMock(return_value=refreshed),
|
||||||
|
):
|
||||||
|
result = await store.get_valid_token()
|
||||||
|
assert result == refreshed
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_store_get_valid_token_does_not_resurrect_after_logout(
|
||||||
|
patched_user_dir,
|
||||||
|
):
|
||||||
|
"""A logout landing *during* an in-flight refresh must not be undone by
|
||||||
|
the refresh writing the token back (the resurrection race)."""
|
||||||
|
store = HfAuthStore()
|
||||||
|
import time
|
||||||
|
|
||||||
|
expired = Token(
|
||||||
|
access_token="old", refresh_token="rf", expires_at=time.time() - 100
|
||||||
|
)
|
||||||
|
store.set_token(expired)
|
||||||
|
refreshed = Token(
|
||||||
|
access_token="new", refresh_token="rf", expires_at=time.time() + 3600
|
||||||
|
)
|
||||||
|
|
||||||
|
async def fake_refresh(_refresh_token):
|
||||||
|
# Simulate the user clicking "Log out" while the refresh is in flight.
|
||||||
|
store.clear()
|
||||||
|
return refreshed
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.hf_auth.oauth.refresh_access_token",
|
||||||
|
new=fake_refresh,
|
||||||
|
):
|
||||||
|
result = await store.get_valid_token()
|
||||||
|
|
||||||
|
# The refresh result is discarded — logout wins, in memory and on disk.
|
||||||
|
assert result is None
|
||||||
|
assert not store.has_token()
|
||||||
|
assert load_token() is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_store_get_valid_returns_none_on_refresh_failure(patched_user_dir):
|
||||||
|
store = HfAuthStore()
|
||||||
|
import time
|
||||||
|
|
||||||
|
expired = Token(
|
||||||
|
access_token="old",
|
||||||
|
refresh_token="rf",
|
||||||
|
expires_at=time.time() - 100,
|
||||||
|
)
|
||||||
|
store.set_token(expired)
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.hf_auth.oauth.refresh_access_token",
|
||||||
|
new=AsyncMock(side_effect=RuntimeError("HF down")),
|
||||||
|
):
|
||||||
|
result = await store.get_valid_token()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Eligibility
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"listen,multi_user,expected",
|
||||||
|
[
|
||||||
|
("127.0.0.1", False, True),
|
||||||
|
("127.0.0.1", True, False), # multi-user disables it
|
||||||
|
("0.0.0.0", False, False), # bind-all is not loopback
|
||||||
|
("0.0.0.0", True, False),
|
||||||
|
("192.168.1.5", False, False), # LAN address
|
||||||
|
("::1", False, True), # IPv6 loopback
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_eligibility(listen, multi_user, expected, monkeypatch):
|
||||||
|
from app.model_downloader.hf_auth import eligibility
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
monkeypatch.setattr(args, "listen", listen)
|
||||||
|
monkeypatch.setattr(args, "multi_user", multi_user)
|
||||||
|
assert eligibility.is_hf_auth_eligible() is expected
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# gated_detection HF probe
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_probe_url_hf_public(fresh_auth_store):
|
||||||
|
"""auth_check succeeds with no token → is_hf_downloadable = True."""
|
||||||
|
from app.model_downloader.gated_detection import probe_url
|
||||||
|
|
||||||
|
url = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
|
||||||
|
with patch("app.model_downloader.gated_detection._auth_check_sync"), patch(
|
||||||
|
"app.model_downloader.gated_detection._probe_size_once",
|
||||||
|
new=AsyncMock(return_value=1024),
|
||||||
|
):
|
||||||
|
result = await probe_url(url)
|
||||||
|
assert result.is_hf_downloadable is True
|
||||||
|
assert result.file_size == 1024
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_probe_url_hf_gated_no_access(fresh_auth_store):
|
||||||
|
"""auth_check raises GatedRepoError → is_hf_downloadable = False."""
|
||||||
|
from huggingface_hub.errors import GatedRepoError
|
||||||
|
|
||||||
|
from app.model_downloader.gated_detection import probe_url
|
||||||
|
|
||||||
|
url = "https://huggingface.co/gated/repo/resolve/main/x.safetensors"
|
||||||
|
fake_response = MagicMock(status_code=403)
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.gated_detection._auth_check_sync",
|
||||||
|
side_effect=GatedRepoError("gated", response=fake_response),
|
||||||
|
), patch(
|
||||||
|
"app.model_downloader.gated_detection._probe_size_once",
|
||||||
|
new=AsyncMock(return_value=None),
|
||||||
|
):
|
||||||
|
result = await probe_url(url)
|
||||||
|
assert result.is_hf_downloadable is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_probe_url_non_hf_skips_auth_check():
|
||||||
|
"""Non-HF URLs never call auth_check; is_hf_downloadable stays None."""
|
||||||
|
from app.model_downloader.gated_detection import probe_url
|
||||||
|
|
||||||
|
url = "https://civitai.com/api/download/models/1.safetensors"
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.gated_detection._auth_check_sync",
|
||||||
|
) as mocked, patch(
|
||||||
|
"app.model_downloader.gated_detection._probe_size_once",
|
||||||
|
new=AsyncMock(return_value=2048),
|
||||||
|
):
|
||||||
|
result = await probe_url(url)
|
||||||
|
assert result.is_hf_downloadable is None
|
||||||
|
assert result.file_size == 2048
|
||||||
|
mocked.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_is_gated_cached_across_calls(fresh_auth_store):
|
||||||
|
"""Intrinsic ``is_gated`` should be determined exactly once per URL.
|
||||||
|
|
||||||
|
Subsequent ``probe_url`` calls for the same URL must not re-issue
|
||||||
|
the null-token auth_check — that's the whole point of the cache."""
|
||||||
|
from app.model_downloader.gated_detection import probe_url
|
||||||
|
|
||||||
|
url = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.gated_detection._auth_check_sync"
|
||||||
|
) as mocked, patch(
|
||||||
|
"app.model_downloader.gated_detection._probe_size_once",
|
||||||
|
new=AsyncMock(return_value=1024),
|
||||||
|
):
|
||||||
|
await probe_url(url)
|
||||||
|
await probe_url(url)
|
||||||
|
await probe_url(url)
|
||||||
|
# Three probe_url calls × public-only-needs-1-auth_check = 1 call total.
|
||||||
|
assert mocked.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_size_cached_across_calls(fresh_auth_store):
|
||||||
|
"""Once a successful HEAD lands, subsequent calls don't re-HEAD."""
|
||||||
|
from app.model_downloader.gated_detection import probe_url
|
||||||
|
|
||||||
|
url = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.gated_detection._auth_check_sync"
|
||||||
|
), patch(
|
||||||
|
"app.model_downloader.gated_detection._probe_size_once",
|
||||||
|
new=AsyncMock(return_value=2048),
|
||||||
|
) as size_probe:
|
||||||
|
r1 = await probe_url(url)
|
||||||
|
r2 = await probe_url(url)
|
||||||
|
assert r1.file_size == 2048
|
||||||
|
assert r2.file_size == 2048
|
||||||
|
assert size_probe.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_size_not_probed_for_gated_no_access(fresh_auth_store):
|
||||||
|
"""When ``is_hf_downloadable`` is False we must NOT HEAD the URL —
|
||||||
|
otherwise a 401-due-to-gating would land as a cached ``None`` that
|
||||||
|
survives a later successful login."""
|
||||||
|
from app.model_downloader.gated_detection import probe_url
|
||||||
|
from huggingface_hub.errors import GatedRepoError
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
url = "https://huggingface.co/gated/repo/resolve/main/x.safetensors"
|
||||||
|
fake_resp = MagicMock(status_code=403)
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.gated_detection._auth_check_sync",
|
||||||
|
side_effect=GatedRepoError("gated", response=fake_resp),
|
||||||
|
), patch(
|
||||||
|
"app.model_downloader.gated_detection._probe_size_once",
|
||||||
|
new=AsyncMock(return_value=None),
|
||||||
|
) as size_probe:
|
||||||
|
result = await probe_url(url)
|
||||||
|
assert result.is_hf_downloadable is False
|
||||||
|
assert result.file_size is None
|
||||||
|
assert size_probe.call_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_probe_url_passes_token_when_available(fresh_auth_store, patched_user_dir):
|
||||||
|
"""For a gated URL, auth_check runs twice: once with token=None to
|
||||||
|
determine the intrinsic ``is_gated`` flag (cached forever), and once
|
||||||
|
with the stored access_token to determine ``is_hf_downloadable`` for
|
||||||
|
the current user."""
|
||||||
|
from app.model_downloader import gated_detection
|
||||||
|
from app.model_downloader.gated_detection import probe_url
|
||||||
|
from huggingface_hub.errors import GatedRepoError
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
gated_detection.clear_caches_for_tests()
|
||||||
|
fresh_auth_store.set_token(Token(
|
||||||
|
access_token="hf_test_token",
|
||||||
|
refresh_token=None,
|
||||||
|
expires_at=9999999999.0,
|
||||||
|
))
|
||||||
|
url = "https://huggingface.co/private/repo/resolve/main/x.safetensors"
|
||||||
|
|
||||||
|
fake_resp = MagicMock(status_code=403)
|
||||||
|
|
||||||
|
def fake_auth_check(repo_id, token):
|
||||||
|
# Null-token call → repo is gated. Subsequent call with the real
|
||||||
|
# token succeeds (user has access).
|
||||||
|
if token is None:
|
||||||
|
raise GatedRepoError("gated", response=fake_resp)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.gated_detection._auth_check_sync",
|
||||||
|
side_effect=fake_auth_check,
|
||||||
|
) as mocked, patch(
|
||||||
|
"app.model_downloader.gated_detection._probe_size_once",
|
||||||
|
new=AsyncMock(return_value=None),
|
||||||
|
):
|
||||||
|
result = await probe_url(url)
|
||||||
|
|
||||||
|
# is_hf_downloadable should be True (token-authed call succeeded).
|
||||||
|
assert result.is_hf_downloadable is True
|
||||||
|
# Two calls: (repo_id, None) then (repo_id, <token>).
|
||||||
|
assert mocked.call_count == 2
|
||||||
|
assert mocked.call_args_list[0].args == ("private/repo", None)
|
||||||
|
assert mocked.call_args_list[1].args == ("private/repo", "hf_test_token")
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# OAuth primitives
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_pkce_returns_distinct_high_entropy_values():
|
||||||
|
verifier1, challenge1, state1 = oauth._make_pkce()
|
||||||
|
verifier2, challenge2, state2 = oauth._make_pkce()
|
||||||
|
assert verifier1 != verifier2
|
||||||
|
assert challenge1 != challenge2
|
||||||
|
assert state1 != state2
|
||||||
|
# Verifier should be at least 43 chars per PKCE spec.
|
||||||
|
assert len(verifier1) >= 43
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_authorize_url_includes_pkce_and_state():
|
||||||
|
url = oauth._build_authorize_url("challenge123", "state456")
|
||||||
|
assert url.startswith(oauth.AUTHORIZE_URL)
|
||||||
|
assert "client_id=" + oauth.HF_CLIENT_ID in url
|
||||||
|
assert "code_challenge=challenge123" in url
|
||||||
|
assert "code_challenge_method=S256" in url
|
||||||
|
assert "state=state456" in url
|
||||||
|
assert "response_type=code" in url
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Routes
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hf_auth_token_status_empty(aiohttp_client, app):
|
||||||
|
"""No token set → token_available=false, username=null."""
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/api/hf-auth-token-status")
|
||||||
|
assert resp.status == 200
|
||||||
|
data = await resp.json()
|
||||||
|
assert data == {"token_available": False, "username": None}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hf_auth_token_status_with_token(
|
||||||
|
aiohttp_client, app, fresh_auth_store, patched_user_dir
|
||||||
|
):
|
||||||
|
"""Token present, whoami works → username is returned."""
|
||||||
|
fresh_auth_store.set_token(Token(
|
||||||
|
access_token="x", refresh_token=None, expires_at=9999999999.0,
|
||||||
|
))
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.api.routes._whoami_username",
|
||||||
|
return_value="alice",
|
||||||
|
):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/api/hf-auth-token-status")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert (await resp.json()) == {"token_available": True, "username": "alice"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hf_auth_login_start_403_when_ineligible(aiohttp_client, app, monkeypatch):
|
||||||
|
"""Not loopback / multi-user → 403."""
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.model_downloader.api.routes.is_hf_auth_eligible",
|
||||||
|
lambda: False,
|
||||||
|
)
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post("/api/hf-auth-login-start")
|
||||||
|
assert resp.status == 403
|
||||||
|
assert (await resp.json())["error"]["code"] == "HF_AUTH_NOT_ELIGIBLE"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hf_auth_login_start_returns_authorize_url(aiohttp_client, app, monkeypatch):
|
||||||
|
"""Eligible + first attempt → 200 with authorize_url."""
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.model_downloader.api.routes.is_hf_auth_eligible",
|
||||||
|
lambda: True,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.model_downloader.api.routes.start_login_flow",
|
||||||
|
AsyncMock(return_value="https://huggingface.co/oauth/authorize?fake=1"),
|
||||||
|
)
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post("/api/hf-auth-login-start")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert (await resp.json())["authorize_url"].startswith(
|
||||||
|
"https://huggingface.co/oauth/authorize"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hf_auth_login_start_409_when_in_progress(aiohttp_client, app, monkeypatch):
|
||||||
|
"""Lock already held → 409."""
|
||||||
|
from app.model_downloader.hf_auth.oauth import OAuthInProgressError
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.model_downloader.api.routes.is_hf_auth_eligible",
|
||||||
|
lambda: True,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.model_downloader.api.routes.start_login_flow",
|
||||||
|
AsyncMock(side_effect=OAuthInProgressError()),
|
||||||
|
)
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post("/api/hf-auth-login-start")
|
||||||
|
assert resp.status == 409
|
||||||
|
assert (await resp.json())["error"]["code"] == "HF_AUTH_IN_PROGRESS"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hf_auth_login_start_503_when_callback_bind_fails(
|
||||||
|
aiohttp_client, app, monkeypatch
|
||||||
|
):
|
||||||
|
"""Callback server failed to bind (e.g. port busy) → 503, not a dead URL."""
|
||||||
|
from app.model_downloader.hf_auth.oauth import OAuthCallbackError
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.model_downloader.api.routes.is_hf_auth_eligible",
|
||||||
|
lambda: True,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.model_downloader.api.routes.start_login_flow",
|
||||||
|
AsyncMock(side_effect=OAuthCallbackError("could not bind callback port")),
|
||||||
|
)
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post("/api/hf-auth-login-start")
|
||||||
|
assert resp.status == 503
|
||||||
|
assert (await resp.json())["error"]["code"] == "HF_AUTH_START_FAILED"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hf_auth_logout_clears_store(
|
||||||
|
aiohttp_client, app, fresh_auth_store, patched_user_dir
|
||||||
|
):
|
||||||
|
fresh_auth_store.set_token(Token(
|
||||||
|
access_token="x", refresh_token=None, expires_at=9999999999.0,
|
||||||
|
))
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post("/api/hf-auth-logout")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert (await resp.json()) == {"logged_out": True}
|
||||||
|
assert not fresh_auth_store.has_token()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_availability_includes_hf_auth_snapshot(aiohttp_client, app, monkeypatch):
|
||||||
|
"""The availability response embeds {token_available, eligible}."""
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.model_downloader.api.routes.is_hf_auth_eligible",
|
||||||
|
lambda: True,
|
||||||
|
)
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/models-availability-status",
|
||||||
|
json={"models": {}},
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
data = await resp.json()
|
||||||
|
assert "hf_auth" in data
|
||||||
|
assert data["hf_auth"] == {"token_available": False, "eligible": True}
|
||||||
514
tests-unit/app_test/model_downloader_test.py
Normal file
514
tests-unit/app_test/model_downloader_test.py
Normal file
@ -0,0 +1,514 @@
|
|||||||
|
"""Unit tests for the server-side model download subsystem.
|
||||||
|
|
||||||
|
Covers the pieces that don't require talking to a real network:
|
||||||
|
|
||||||
|
- path parsing & allowlist (pure functions)
|
||||||
|
- DownloadServer registry lifecycle (in-memory state)
|
||||||
|
- API routes via aiohttp_client + folder_paths/probe_url patches
|
||||||
|
|
||||||
|
Streaming downloads themselves are exercised indirectly — the route-level
|
||||||
|
tests stub out the network probe so we can verify the gating logic in
|
||||||
|
``download_models`` without making real HTTP calls.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import patch, AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from app.model_downloader.allowlist import is_url_allowed
|
||||||
|
from app.model_downloader.api.routes import register_routes
|
||||||
|
from app.model_downloader.download_server import DownloadServer
|
||||||
|
from app.model_downloader.gated_detection import MetadataProbeResult
|
||||||
|
from app.model_downloader.paths import (
|
||||||
|
InvalidModelId,
|
||||||
|
parse_model_id,
|
||||||
|
resolve_destination,
|
||||||
|
resolve_existing,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Fixtures
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model_root(tmp_path):
|
||||||
|
"""A fake ``models/`` root with two registered folder types."""
|
||||||
|
loras_dir = tmp_path / "loras"
|
||||||
|
checkpoints_dir = tmp_path / "checkpoints"
|
||||||
|
loras_dir.mkdir()
|
||||||
|
checkpoints_dir.mkdir()
|
||||||
|
return tmp_path, loras_dir, checkpoints_dir
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def patched_folder_paths(model_root):
|
||||||
|
"""Point folder_paths at our fake roots for the duration of one test."""
|
||||||
|
_root, loras_dir, checkpoints_dir = model_root
|
||||||
|
mapping = {
|
||||||
|
"loras": ([str(loras_dir)], {".safetensors"}),
|
||||||
|
"checkpoints": ([str(checkpoints_dir)], {".safetensors"}),
|
||||||
|
}
|
||||||
|
with patch(
|
||||||
|
"folder_paths.folder_names_and_paths", mapping
|
||||||
|
), patch(
|
||||||
|
"folder_paths.get_folder_paths",
|
||||||
|
side_effect=lambda name: mapping.get(name, ([], set()))[0],
|
||||||
|
):
|
||||||
|
yield mapping
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fresh_download_server():
|
||||||
|
"""Reset the module-level singleton between tests so registry state
|
||||||
|
doesn't leak across tests sharing the singleton."""
|
||||||
|
from app.model_downloader.download_server import DOWNLOAD_SERVER
|
||||||
|
|
||||||
|
DOWNLOAD_SERVER.reset_for_tests()
|
||||||
|
yield DOWNLOAD_SERVER
|
||||||
|
DOWNLOAD_SERVER.reset_for_tests()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(patched_folder_paths, fresh_download_server):
|
||||||
|
app = web.Application()
|
||||||
|
register_routes(app)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Pure helpers: allowlist + path parsing
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
def test_allowlist_accepts_hf_safetensors():
|
||||||
|
assert is_url_allowed("https://huggingface.co/x/y/resolve/main/z.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_allowlist_accepts_civitai_pth():
|
||||||
|
assert is_url_allowed("https://civitai.com/api/download/models/123.pth")
|
||||||
|
|
||||||
|
|
||||||
|
def test_allowlist_rejects_unknown_host():
|
||||||
|
assert not is_url_allowed("https://example.com/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_allowlist_rejects_api_path_on_hf():
|
||||||
|
# On an allowlisted host but not pointing at a model file.
|
||||||
|
assert not is_url_allowed("https://huggingface.co/api/models")
|
||||||
|
|
||||||
|
|
||||||
|
def test_allowlist_rejects_non_https_except_localhost():
|
||||||
|
assert not is_url_allowed("http://huggingface.co/x/y.safetensors")
|
||||||
|
assert is_url_allowed("http://localhost:8000/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_model_id_valid(patched_folder_paths):
|
||||||
|
assert parse_model_id("loras/foo.safetensors") == ("loras", "foo.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_model_id_rejects_traversal(patched_folder_paths):
|
||||||
|
with pytest.raises(InvalidModelId):
|
||||||
|
parse_model_id("../etc/passwd")
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_model_id_rejects_unknown_folder(patched_folder_paths):
|
||||||
|
with pytest.raises(InvalidModelId):
|
||||||
|
parse_model_id("nope/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_model_id_rejects_double_slash(patched_folder_paths):
|
||||||
|
with pytest.raises(InvalidModelId):
|
||||||
|
parse_model_id("loras/sub/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_existing_returns_path_when_present(model_root, patched_folder_paths):
|
||||||
|
_root, loras_dir, _ = model_root
|
||||||
|
target = loras_dir / "foo.safetensors"
|
||||||
|
target.write_bytes(b"x")
|
||||||
|
assert resolve_existing("loras/foo.safetensors") == str(target)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_existing_returns_none_when_absent(patched_folder_paths):
|
||||||
|
assert resolve_existing("loras/missing.safetensors") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_destination_returns_tmp_pair(model_root, patched_folder_paths):
|
||||||
|
_root, loras_dir, _ = model_root
|
||||||
|
final, tmp = resolve_destination("loras/foo.safetensors", epoch=7)
|
||||||
|
assert final == str(loras_dir / "foo.safetensors")
|
||||||
|
# Temp path embeds the session epoch (so cancel+retry can't collide on it)
|
||||||
|
# and uses the subsystem-specific suffix the startup sweep matches.
|
||||||
|
assert tmp == f"{final}.7.comfy-download.tmp"
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# DownloadServer registry: lifecycle, races, cancellation epoch semantics
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_is_exclusive():
|
||||||
|
server = DownloadServer()
|
||||||
|
s1 = server.try_register("loras/x.safetensors", "https://huggingface.co/a")
|
||||||
|
s2 = server.try_register("loras/x.safetensors", "https://huggingface.co/b")
|
||||||
|
assert s1 is not None
|
||||||
|
assert s2 is None
|
||||||
|
assert server.is_downloading("loras/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_removes_session():
|
||||||
|
server = DownloadServer()
|
||||||
|
server.try_register("loras/x.safetensors", "https://huggingface.co/a")
|
||||||
|
assert server.cancel("loras/x.safetensors") is True
|
||||||
|
assert not server.is_downloading("loras/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_returns_false_when_absent():
|
||||||
|
server = DownloadServer()
|
||||||
|
assert server.cancel("loras/never.safetensors") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_finish_only_clears_matching_epoch():
|
||||||
|
"""If a session is cancelled and a new one for the same id is
|
||||||
|
registered, ``finish`` from the original worker must not evict the
|
||||||
|
newer session."""
|
||||||
|
server = DownloadServer()
|
||||||
|
s_old = server.try_register("loras/x.safetensors", "u1")
|
||||||
|
server.cancel("loras/x.safetensors")
|
||||||
|
s_new = server.try_register("loras/x.safetensors", "u2")
|
||||||
|
assert s_new is not None and s_new.epoch != s_old.epoch
|
||||||
|
# Old worker's late finish() is a no-op:
|
||||||
|
server.finish(s_old)
|
||||||
|
assert server.is_downloading("loras/x.safetensors")
|
||||||
|
server.finish(s_new)
|
||||||
|
assert not server.is_downloading("loras/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_active_follows_cancellation():
|
||||||
|
server = DownloadServer()
|
||||||
|
s = server.try_register("loras/x.safetensors", "u")
|
||||||
|
assert server.is_active(s)
|
||||||
|
server.cancel("loras/x.safetensors")
|
||||||
|
assert not server.is_active(s)
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_progress_tracks_fraction():
|
||||||
|
server = DownloadServer()
|
||||||
|
s = server.try_register("loras/x.safetensors", "u")
|
||||||
|
server.update_progress(s, 50, 100)
|
||||||
|
snap = server.snapshot()["loras/x.safetensors"]
|
||||||
|
assert snap.bytes_downloaded == 50
|
||||||
|
assert snap.total_bytes == 100
|
||||||
|
assert snap.progress == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_progress_with_unknown_total_keeps_progress_none():
|
||||||
|
server = DownloadServer()
|
||||||
|
s = server.try_register("loras/x.safetensors", "u")
|
||||||
|
server.update_progress(s, 50, None)
|
||||||
|
assert server.snapshot()["loras/x.safetensors"].progress is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_cleanup_orphan_tmp_files(model_root):
|
||||||
|
"""Orphan temp left by a crashed download must be swept on first use,
|
||||||
|
while unrelated *.tmp files in the model dir are left untouched."""
|
||||||
|
_root, loras_dir, _ = model_root
|
||||||
|
orphan = loras_dir / "stale.safetensors.3.comfy-download.tmp"
|
||||||
|
orphan.write_bytes(b"partial")
|
||||||
|
unrelated = loras_dir / "someothertool.tmp"
|
||||||
|
unrelated.write_bytes(b"not ours")
|
||||||
|
mapping = {"loras": ([str(loras_dir)], {".safetensors"})}
|
||||||
|
with patch("folder_paths.folder_names_and_paths", mapping), patch(
|
||||||
|
"folder_paths.get_folder_paths",
|
||||||
|
side_effect=lambda name: mapping.get(name, ([], set()))[0],
|
||||||
|
):
|
||||||
|
server = DownloadServer()
|
||||||
|
assert orphan.exists(), "sweep must not run at construction time"
|
||||||
|
server.sweep_orphan_tmp_files()
|
||||||
|
assert not orphan.exists()
|
||||||
|
assert unrelated.exists(), "unrelated .tmp must not be swept"
|
||||||
|
# Idempotent — a second call is a cheap no-op.
|
||||||
|
server.sweep_orphan_tmp_files()
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Route: POST /api/models-availability-status
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_availability_partitions_correctly(
|
||||||
|
aiohttp_client, app, model_root, fresh_download_server
|
||||||
|
):
|
||||||
|
_root, loras_dir, _ = model_root
|
||||||
|
(loras_dir / "present.safetensors").write_bytes(b"x")
|
||||||
|
fresh_download_server.try_register(
|
||||||
|
"loras/inflight.safetensors", "http://localhost:8000/x.safetensors"
|
||||||
|
)
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
|
# Stub probes — we're testing state assignment, not network calls.
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.api.routes.probe_url",
|
||||||
|
new=AsyncMock(return_value=MetadataProbeResult(
|
||||||
|
file_size=None, is_hf_downloadable=None,
|
||||||
|
)),
|
||||||
|
):
|
||||||
|
body = {
|
||||||
|
"models": {
|
||||||
|
"loras/present.safetensors": "http://localhost:8000/p.safetensors",
|
||||||
|
"loras/missing.safetensors": "http://localhost:8000/m.safetensors",
|
||||||
|
"loras/inflight.safetensors": "http://localhost:8000/x.safetensors",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp = await client.post("/api/models-availability-status", json=body)
|
||||||
|
assert resp.status == 200
|
||||||
|
data = await resp.json()
|
||||||
|
models = data["models"]
|
||||||
|
assert models["loras/present.safetensors"]["state"] == "available"
|
||||||
|
assert models["loras/missing.safetensors"]["state"] == "missing"
|
||||||
|
assert models["loras/inflight.safetensors"]["state"] == "downloading"
|
||||||
|
assert "hf_auth" in data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_availability_invalid_id_classified_as_missing(aiohttp_client, app):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.api.routes.probe_url",
|
||||||
|
new=AsyncMock(return_value=MetadataProbeResult(
|
||||||
|
file_size=None, is_hf_downloadable=None,
|
||||||
|
)),
|
||||||
|
):
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/models-availability-status",
|
||||||
|
json={"models": {"../etc/passwd": "http://localhost:8000/x.safetensors"}},
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
data = await resp.json()
|
||||||
|
assert data["models"]["../etc/passwd"]["state"] == "missing"
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Route: POST /api/download-models — precondition gating
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_rejects_url_not_in_allowlist(aiohttp_client, app):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/download-models",
|
||||||
|
json={"models": {"loras/x.safetensors": "https://evil.com/x.safetensors"}},
|
||||||
|
)
|
||||||
|
assert resp.status == 400
|
||||||
|
err = (await resp.json())["error"]
|
||||||
|
assert err["code"] == "URL_NOT_ALLOWED"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_rejects_already_available(
|
||||||
|
aiohttp_client, app, model_root
|
||||||
|
):
|
||||||
|
_root, loras_dir, _ = model_root
|
||||||
|
(loras_dir / "x.safetensors").write_bytes(b"x")
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/download-models",
|
||||||
|
json={"models": {
|
||||||
|
"loras/x.safetensors": "https://huggingface.co/a/b/resolve/main/x.safetensors"
|
||||||
|
}},
|
||||||
|
)
|
||||||
|
assert resp.status == 409
|
||||||
|
assert (await resp.json())["error"]["code"] == "ALREADY_AVAILABLE"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_rejects_already_downloading(
|
||||||
|
aiohttp_client, app, fresh_download_server
|
||||||
|
):
|
||||||
|
fresh_download_server.try_register(
|
||||||
|
"loras/x.safetensors", "https://huggingface.co/u.safetensors"
|
||||||
|
)
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/download-models",
|
||||||
|
json={"models": {
|
||||||
|
"loras/x.safetensors": "https://huggingface.co/a/b/resolve/main/x.safetensors"
|
||||||
|
}},
|
||||||
|
)
|
||||||
|
assert resp.status == 409
|
||||||
|
assert (await resp.json())["error"]["code"] == "ALREADY_DOWNLOADING"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_rejects_gated_model(aiohttp_client, app):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.api.routes.probe_url",
|
||||||
|
new=AsyncMock(return_value=MetadataProbeResult(file_size=None, is_hf_downloadable=False)),
|
||||||
|
):
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/download-models",
|
||||||
|
json={"models": {
|
||||||
|
"loras/x.safetensors": "https://huggingface.co/g/r/resolve/main/x.safetensors"
|
||||||
|
}},
|
||||||
|
)
|
||||||
|
assert resp.status == 400
|
||||||
|
assert (await resp.json())["error"]["code"] == "MODEL_NOT_DOWNLOADABLE"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_rejects_invalid_model_id(aiohttp_client, app):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/download-models",
|
||||||
|
json={"models": {"../etc/passwd": "https://huggingface.co/x.safetensors"}},
|
||||||
|
)
|
||||||
|
assert resp.status == 400
|
||||||
|
assert (await resp.json())["error"]["code"] == "INVALID_MODEL_ID"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_atomic_failure_does_not_register_partial(
|
||||||
|
aiohttp_client, app, model_root, fresh_download_server
|
||||||
|
):
|
||||||
|
"""If one model in a batch fails, none get registered."""
|
||||||
|
_root, loras_dir, _ = model_root
|
||||||
|
(loras_dir / "already.safetensors").write_bytes(b"x")
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/download-models",
|
||||||
|
json={
|
||||||
|
"models": {
|
||||||
|
"loras/already.safetensors":
|
||||||
|
"https://huggingface.co/a/b/resolve/main/already.safetensors",
|
||||||
|
"loras/new.safetensors":
|
||||||
|
"https://huggingface.co/a/b/resolve/main/new.safetensors",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status == 409
|
||||||
|
# The "new" model should not have been registered as part of the
|
||||||
|
# failed batch.
|
||||||
|
assert not fresh_download_server.is_downloading("loras/new.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_schedules_when_all_preconditions_pass(
|
||||||
|
aiohttp_client, app, fresh_download_server
|
||||||
|
):
|
||||||
|
"""Verify the precondition pass, registration pass, and async
|
||||||
|
scheduling all wire up correctly. We patch the streamer to avoid
|
||||||
|
real HTTP while still letting the route execute end-to-end."""
|
||||||
|
started = asyncio.Event()
|
||||||
|
finish_signal = asyncio.Event()
|
||||||
|
|
||||||
|
async def fake_stream(session):
|
||||||
|
started.set()
|
||||||
|
await finish_signal.wait()
|
||||||
|
from app.model_downloader.download_server import DOWNLOAD_SERVER
|
||||||
|
DOWNLOAD_SERVER.finish(session)
|
||||||
|
return "/dev/null"
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.api.routes.probe_url",
|
||||||
|
new=AsyncMock(return_value=MetadataProbeResult(file_size=42, is_hf_downloadable=True)),
|
||||||
|
), patch(
|
||||||
|
"app.model_downloader.downloader.stream_to_disk", new=fake_stream
|
||||||
|
):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/download-models",
|
||||||
|
json={"models": {
|
||||||
|
"loras/new.safetensors":
|
||||||
|
"https://huggingface.co/a/b/resolve/main/new.safetensors"
|
||||||
|
}},
|
||||||
|
)
|
||||||
|
assert resp.status == 202
|
||||||
|
body = await resp.json()
|
||||||
|
assert body["accepted"] is True
|
||||||
|
assert body["scheduled"] == ["loras/new.safetensors"]
|
||||||
|
# Wait for the worker to actually start.
|
||||||
|
await asyncio.wait_for(started.wait(), timeout=2.0)
|
||||||
|
assert fresh_download_server.is_downloading("loras/new.safetensors")
|
||||||
|
finish_signal.set()
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Route: POST /api/cancel-model-download-session
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_removes_active_session(
|
||||||
|
aiohttp_client, app, fresh_download_server
|
||||||
|
):
|
||||||
|
fresh_download_server.try_register(
|
||||||
|
"loras/x.safetensors", "https://huggingface.co/u.safetensors"
|
||||||
|
)
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/cancel-model-download-session",
|
||||||
|
json={"model_id": "loras/x.safetensors"},
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
assert (await resp.json())["cancelled"] is True
|
||||||
|
assert not fresh_download_server.is_downloading("loras/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_returns_404_when_none(aiohttp_client, app):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/cancel-model-download-session",
|
||||||
|
json={"model_id": "loras/nothing.safetensors"},
|
||||||
|
)
|
||||||
|
assert resp.status == 404
|
||||||
|
assert (await resp.json())["error"]["code"] == "NOT_DOWNLOADING"
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Unified availability response embeds metadata per id
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_availability_embeds_metadata(aiohttp_client, app):
|
||||||
|
"""``file_size`` + ``is_hf_downloadable`` come back on the same
|
||||||
|
request as the state — no separate metadata endpoint."""
|
||||||
|
results = {
|
||||||
|
"https://huggingface.co/a/b/resolve/main/free.safetensors":
|
||||||
|
MetadataProbeResult(file_size=1024, is_hf_downloadable=True),
|
||||||
|
"https://huggingface.co/g/r/resolve/main/gated.safetensors":
|
||||||
|
MetadataProbeResult(file_size=None, is_hf_downloadable=False),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def fake_probe(url):
|
||||||
|
return results[url]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.api.routes.probe_url", new=fake_probe
|
||||||
|
):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/models-availability-status",
|
||||||
|
json={
|
||||||
|
"models": {
|
||||||
|
"loras/free.safetensors":
|
||||||
|
"https://huggingface.co/a/b/resolve/main/free.safetensors",
|
||||||
|
"loras/gated.safetensors":
|
||||||
|
"https://huggingface.co/g/r/resolve/main/gated.safetensors",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
models = (await resp.json())["models"]
|
||||||
|
assert models["loras/free.safetensors"]["file_size"] == 1024
|
||||||
|
assert models["loras/free.safetensors"]["is_hf_downloadable"] is True
|
||||||
|
assert models["loras/gated.safetensors"]["file_size"] is None
|
||||||
|
assert models["loras/gated.safetensors"]["is_hf_downloadable"] is False
|
||||||
Loading…
Reference in New Issue
Block a user