ComfyUI/app/model_downloader/api/routes.py
DoronGenzelHass fdd84d04a0 feat(model_downloader): server-side model download + HuggingFace OAuth subsystem
Self-contained package under app/model_downloader/:
- Allowlist + path-validated downloads (SSRF guard: HF/Civitai/localhost + model extension).
- Streaming worker: writes to <final>.tmp, atomic rename on success, cooperative cancellation with epoch-based session identity, orphan .tmp sweep.
- Unified availability probe with per-URL gated/size caching; is_hf_downloadable recomputed per call so login/license changes surface within one poll.
- HuggingFace OAuth 2.0 PKCE flow with loopback callback server and on-disk (0600) token storage + transparent refresh.
- Pydantic request/response schemas and aiohttp routes under api/.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-22 15:16:59 +03:00

333 lines
12 KiB
Python

"""Aiohttp routes for the server-side model download subsystem.
Endpoint surface (all under ``/api/``, all kebab-case):
- ``POST /api/models-availability-status`` — bulk status + metadata query.
- ``POST /api/download-models`` — start a batch of downloads.
- ``POST /api/cancel-model-download-session`` — cancel a single in-flight download.
- ``GET /api/hf-auth-token-status`` — current HF login state.
- ``POST /api/hf-auth-login-start`` — begin the HF OAuth flow.
- ``POST /api/hf-auth-logout`` — drop the stored HF token.
The contract is intentionally narrow: only model_ids of the form
``<directory>/<filename>`` (validated via ``app.model_downloader.paths``)
are accepted, and only URLs on the same allowlist the frontend already
uses (HuggingFace, Civitai, localhost) can be fetched. Both are required
to keep the server out of the SSRF business for this feature.
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any, Optional
from aiohttp import web
from pydantic import BaseModel, ValidationError
from app.model_downloader.allowlist import is_url_allowed
from app.model_downloader.download_server import (
DOWNLOAD_SERVER,
DownloadSession,
)
from app.model_downloader.downloader import schedule_batch
from app.model_downloader.gated_detection import probe_url
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
from app.model_downloader.hf_auth.eligibility import is_hf_auth_eligible
from app.model_downloader.hf_auth.oauth import (
OAuthInProgressError,
start_login_flow,
)
from app.model_downloader.paths import (
InvalidModelId,
parse_model_id,
resolve_existing,
)
from app.model_downloader.api import schemas_in, schemas_out
# Route table that collects every handler declared in this module via the
# ``@ROUTES.get`` / ``@ROUTES.post`` decorators; ``register_routes`` attaches
# it to the aiohttp application.
ROUTES = web.RouteTableDef()
def register_routes(app: web.Application) -> None:
    """Attach every handler collected in ``ROUTES`` to *app*.

    Meant to be invoked a single time by ``server.py`` while the
    ``PromptServer`` is starting up.
    """
    app.add_routes(ROUTES)
# ----- response helpers (same envelope as app/assets/api/routes.py) -----
def _error(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
    """Build the JSON error response shared by the routes in this module.

    The body has the shape
    ``{"error": {"code": ..., "message": ..., "details": ...}}`` and is sent
    with HTTP status *status*. A falsy *details* (``None`` or an empty dict)
    is emitted as an empty object.
    """
    envelope = {
        "error": {
            "code": code,
            "message": message,
            "details": details if details else {},
        },
    }
    return web.json_response(envelope, status=status)
def _validation_error(code: str, ve: ValidationError) -> web.Response:
    """Turn a pydantic ``ValidationError`` into a 400 error response.

    The individual validation failures are exposed under
    ``details["errors"]``.
    """
    # Round-trip through ve.json() so the error list is plain JSON data.
    error_list = json.loads(ve.json())
    return _error(400, code, "Validation failed.", {"errors": error_list})
def _ok(payload: BaseModel, status: int = 200) -> web.Response:
    """Serialize a pydantic response model into a JSON response.

    Fields whose value is ``None`` are kept in the output rather than
    dropped. *status* defaults to 200.
    """
    body = payload.model_dump(mode="json", exclude_none=False)
    return web.json_response(body, status=status)
async def _parse_body(request: web.Request, model: type[BaseModel]) -> Any:
    """Parse the JSON request body into an instance of *model*.

    Returns the validated model instance on success. On failure this does
    NOT raise: it returns a ready-to-send 400 ``web.Response``, so callers
    must check ``isinstance(result, web.Response)`` and return it unchanged.

    Failure responses:
    - ``INVALID_JSON`` — the body is not decodable text or not valid JSON.
    - ``INVALID_BODY`` — the JSON does not satisfy *model*'s schema.
    """
    try:
        raw = await request.json()
    except (json.JSONDecodeError, UnicodeDecodeError):
        # request.json() decodes the body to text before parsing, so bytes
        # that don't match the charset raise UnicodeDecodeError. Treat that
        # as a client error too instead of letting it escape as a 500.
        return _error(400, "INVALID_JSON", "Request body must be valid JSON.")
    try:
        return model.model_validate(raw)
    except ValidationError as ve:
        return _validation_error("INVALID_BODY", ve)
# ----- 1. availability status (unified: state + metadata per id) -----
@ROUTES.post("/api/models-availability-status")
async def models_availability_status(request: web.Request) -> web.Response:
    """Report ``{state, progress, file_size, is_hf_downloadable}`` per model id.

    The request maps model ids to URLs. For each entry:

    - ``state`` is ``downloading`` when a session is registered for the id,
      otherwise ``available`` / ``missing`` depending on whether the file
      resolves on disk. An ill-formed id is reported as ``missing`` instead
      of failing the whole batch.
    - ``progress`` is only populated while ``downloading``.
    - ``file_size`` and ``is_hf_downloadable`` are taken from the URL probe.

    The response also carries the current HF auth status (token present,
    deployment eligible) so the UI can react to login changes on the next
    poll.
    """
    body = await _parse_body(request, schemas_in.AvailabilityStatusRequest)
    if isinstance(body, web.Response):
        return body

    entries = list(body.models.items())

    # NOTE(review): this handler probes the supplied URLs without calling
    # is_url_allowed() itself — confirm that probe_url or the request schema
    # enforces the allowlist.
    # All probes are awaited together; results come back in request order.
    probe_results = await asyncio.gather(*(probe_url(url) for _, url in entries))

    statuses: dict[str, schemas_out.ModelStatusEntry] = {}
    for (model_id, _), probe in zip(entries, probe_results):
        # Only the fields that vary by case are collected here; the
        # probe-derived fields are shared by every outcome.
        fields: dict[str, Any] = {}
        try:
            parse_model_id(model_id)
        except InvalidModelId:
            # Probably a typo in the workflow: answer "missing" for this id
            # rather than rejecting the entire request.
            fields["state"] = "missing"
        else:
            in_flight = DOWNLOAD_SERVER.get(model_id)
            if in_flight is None:
                on_disk = resolve_existing(model_id) is not None
                fields["state"] = "available" if on_disk else "missing"
            else:
                fields["state"] = "downloading"
                fields["progress"] = schemas_out.DownloadProgress(
                    bytes_downloaded=in_flight.bytes_downloaded,
                    total_bytes=in_flight.total_bytes,
                    progress=in_flight.progress,
                )
        statuses[model_id] = schemas_out.ModelStatusEntry(
            file_size=probe.file_size,
            is_hf_downloadable=probe.is_hf_downloadable,
            **fields,
        )

    auth_status = schemas_out.HfAuthStatus(
        token_available=HF_AUTH_STORE.has_token(),
        eligible=is_hf_auth_eligible(),
    )
    return _ok(
        schemas_out.AvailabilityStatusResponse(models=statuses, hf_auth=auth_status)
    )
# ----- 2. start downloads -----
@ROUTES.post("/api/download-models")
async def download_models(request: web.Request) -> web.Response:
    """Start a batch of server-side downloads with all-or-nothing semantics.

    The request maps model ids to URLs. Every entry must pass every
    precondition (well-formed id, allow-listed URL, not already on disk,
    not already downloading, not reported as non-downloadable by the
    probe); the first failure aborts the request with a 4xx error and
    nothing is registered. On success all sessions are registered and
    scheduled, and a 202 response lists the scheduled ids.
    """
    body = await _parse_body(request, schemas_in.DownloadModelsRequest)
    if isinstance(body, web.Response):
        return body
    if not body.models:
        return _error(400, "EMPTY_REQUEST", "No models supplied.")

    batch = list(body.models.items())

    # --- Phase 1: local checks only; no state is touched yet. ---
    for model_id, url in batch:
        try:
            parse_model_id(model_id)
        except InvalidModelId as exc:
            return _error(
                400, "INVALID_MODEL_ID", str(exc), {"model_id": model_id}
            )

        if not is_url_allowed(url):
            return _error(
                400,
                "URL_NOT_ALLOWED",
                "Server-side downloads only accept HuggingFace, Civitai, "
                "or localhost URLs ending in a known model extension.",
                {"model_id": model_id, "url": url},
            )

        if resolve_existing(model_id) is not None:
            return _error(
                409,
                "ALREADY_AVAILABLE",
                f"Model already exists on disk: {model_id}",
                {"model_id": model_id},
            )

        if DOWNLOAD_SERVER.is_downloading(model_id):
            return _error(
                409,
                "ALREADY_DOWNLOADING",
                f"A download for {model_id} is already in progress.",
                {"model_id": model_id},
            )

    # --- Phase 2: reachability, deliberately last because it is the only
    # check that goes to the network. Probes are awaited together. An
    # ``is_hf_downloadable`` of None (non-HF URL) means "no information"
    # and does not block; only an explicit False does. ---
    probe_results = await asyncio.gather(*(probe_url(url) for _, url in batch))
    for (model_id, url), probe in zip(batch, probe_results):
        if probe.is_hf_downloadable is not False:
            continue
        return _error(
            400, "MODEL_NOT_DOWNLOADABLE",
            f"Model {model_id} is gated on HuggingFace and the current "
            f"server token (if any) does not grant access.",
            {"model_id": model_id, "url": url},
        )

    # --- Phase 3: registration. A concurrent request may have slipped in
    # since phase 1 (we awaited in between), so a None from try_register is
    # treated as a lost race and everything registered so far is undone. ---
    registered: list[DownloadSession] = []
    for model_id, url in batch:
        new_session = DOWNLOAD_SERVER.try_register(model_id, url)
        if new_session is not None:
            registered.append(new_session)
            continue
        for earlier in registered:
            DOWNLOAD_SERVER.cancel(earlier.model_id)
        return _error(
            409,
            "ALREADY_DOWNLOADING",
            f"A download for {model_id} is already in progress (race).",
            {"model_id": model_id},
        )

    DOWNLOAD_SERVER.sweep_orphan_tmp_files()
    schedule_batch(registered)

    scheduled_ids = [session.model_id for session in registered]
    logging.info(
        "[model_downloader] scheduled %d downloads: %s",
        len(registered), scheduled_ids,
    )
    accepted = schemas_out.DownloadModelsResponse(
        accepted=True,
        scheduled=scheduled_ids,
    )
    return _ok(accepted, status=202)
# ----- 3. cancel a session -----
@ROUTES.post("/api/cancel-model-download-session")
async def cancel_model_download_session(request: web.Request) -> web.Response:
    """Cancel the download session for the ``model_id`` in the request body.

    Responds with ``cancelled=True`` when ``DOWNLOAD_SERVER.cancel`` reports
    success, and with a 404 ``NOT_DOWNLOADING`` error otherwise.
    """
    body = await _parse_body(request, schemas_in.CancelDownloadSessionRequest)
    if isinstance(body, web.Response):
        return body

    if DOWNLOAD_SERVER.cancel(body.model_id):
        return _ok(schemas_out.CancelDownloadSessionResponse(cancelled=True))

    return _error(
        404,
        "NOT_DOWNLOADING",
        f"No active download for {body.model_id}.",
        {"model_id": body.model_id},
    )
# ----- 4. HuggingFace OAuth status / login start / logout -----
@ROUTES.get("/api/hf-auth-token-status")
async def hf_auth_token_status(request: web.Request) -> web.Response:
    """Report whether the server holds an HF token and, if so, whose it is.

    ``token_available`` mirrors ``HF_AUTH_STORE.has_token()``. ``username``
    is resolved best-effort: it stays ``None`` when no valid token can be
    obtained or when the whoami lookup fails (the failure is logged at
    debug level, never surfaced to the client).
    """
    has_token = HF_AUTH_STORE.has_token()
    username: Optional[str] = None

    token = await HF_AUTH_STORE.get_valid_token() if has_token else None
    if token is not None:
        try:
            # The whoami helper is synchronous and blocks on the network,
            # so it is pushed to a worker thread.
            username = await asyncio.to_thread(
                _whoami_username, token.access_token
            )
        except Exception as exc:
            logging.debug("[hf_auth] whoami failed: %s", exc)

    status = schemas_out.HfAuthTokenStatusResponse(
        token_available=has_token,
        username=username,
    )
    return _ok(status)
def _whoami_username(token: str) -> Optional[str]:
    """Blocking helper: look up the HF account name that owns *token*.

    Returns the ``name`` field of the ``whoami`` payload, falling back to
    ``fullname`` when ``name`` is missing or falsy. Returns ``None`` when
    the payload is not a dict.
    """
    from huggingface_hub import HfApi

    whoami_payload = HfApi().whoami(token=token)
    if not isinstance(whoami_payload, dict):
        return None
    return whoami_payload.get("name") or whoami_payload.get("fullname")
@ROUTES.post("/api/hf-auth-login-start")
async def hf_auth_login_start(request: web.Request) -> web.Response:
    """Kick off a single HF OAuth attempt and return its authorize URL.

    Responds 403 ``HF_AUTH_NOT_ELIGIBLE`` when ``is_hf_auth_eligible()`` is
    false, and 409 ``HF_AUTH_IN_PROGRESS`` when another login attempt is
    already running.
    """
    if is_hf_auth_eligible():
        try:
            authorize_url = await start_login_flow()
        except OAuthInProgressError:
            return _error(
                409,
                "HF_AUTH_IN_PROGRESS",
                "Another HuggingFace login attempt is in progress. Try again "
                "after it completes or times out.",
            )
        return _ok(
            schemas_out.HfAuthLoginStartResponse(authorize_url=authorize_url)
        )

    return _error(
        403,
        "HF_AUTH_NOT_ELIGIBLE",
        "This server is not eligible for interactive HuggingFace login. "
        "It must be bound to a loopback address and not running in "
        "--multi-user mode.",
    )
@ROUTES.post("/api/hf-auth-logout")
async def hf_auth_logout(request: web.Request) -> web.Response:
    """Forget the stored HF token (``HF_AUTH_STORE.clear()``) and confirm."""
    HF_AUTH_STORE.clear()
    confirmation = schemas_out.HfAuthLogoutResponse(logged_out=True)
    return _ok(confirmation)