feat(model_downloader): server-side model download + HuggingFace OAuth subsystem

Self-contained package under app/model_downloader/:
- Allowlist + path-validated downloads (SSRF guard: HF/Civitai/localhost + model extension).
- Streaming worker: writes to <final>.tmp, atomic rename on success, cooperative cancellation with epoch-based session identity, orphan .tmp sweep.
- Unified availability probe with per-URL gated/size caching; is_hf_downloadable recomputed per call so login/license changes surface within one poll.
- HuggingFace OAuth 2.0 PKCE flow with loopback callback server and on-disk (0600) token storage + transparent refresh.
- Pydantic request/response schemas and aiohttp routes under api/.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
DoronGenzelHass 2026-06-22 12:02:46 +03:00
parent 4676481609
commit fdd84d04a0
14 changed files with 1846 additions and 0 deletions

View File

@ -0,0 +1,46 @@
"""URL allowlist for server-side model fetches.
Mirrors the frontend's ``isModelDownloadable`` allowlist so the two flows
agree on which URLs are eligible for download. Server-side allowlisting is
the primary SSRF defense for this subsystem workflow JSON is untrusted
input (anyone can hand-craft one), so we never let the server fetch URLs
outside this list.
"""
from urllib.parse import urlparse
# Frontend parity: ``missingModelDownload-*.js`` exports the same triple as
# ``i = [...]`` (Civitai / HuggingFace / localhost).
_ALLOWED_URL_PREFIXES = (
"https://huggingface.co/",
"https://civitai.com/",
"http://localhost:",
"http://127.0.0.1:",
)
# Frontend parity: same set as ``a = [...]`` in the bundle.
_ALLOWED_MODEL_EXTENSIONS = (
".safetensors",
".sft",
".ckpt",
".pth",
".pt",
)
def is_url_allowed(url: str) -> bool:
"""Check whether ``url`` is permitted as a server-side download source.
Returns True only when both:
- the URL starts with one of the allowed prefixes, AND
- the URL's final path segment ends with a known model extension.
Both checks are required to keep arbitrary HTML / API endpoints on
allowlisted hosts (e.g. ``https://huggingface.co/api/...``) off the table.
"""
if not isinstance(url, str) or not url:
return False
if not any(url.startswith(p) for p in _ALLOWED_URL_PREFIXES):
return False
path = urlparse(url).path
return any(path.endswith(ext) for ext in _ALLOWED_MODEL_EXTENSIONS)

View File

@ -0,0 +1,332 @@
"""Aiohttp routes for the server-side model download subsystem.
Endpoint surface (all under ``/api/``, all kebab-case):
- ``POST /api/models-availability-status`` bulk status + metadata query.
- ``POST /api/download-models`` start a batch of downloads.
- ``POST /api/cancel-model-download-session`` cancel a single in-flight one.
- ``GET /api/hf-auth-token-status`` current HF login state.
- ``POST /api/hf-auth-login-start`` begin the HF OAuth flow.
- ``POST /api/hf-auth-logout`` drop the stored HF token.
The contract is intentionally narrow: only model_ids of the form
``<directory>/<filename>`` (validated via ``app.model_downloader.paths``)
are accepted, and only URLs on the same allowlist the frontend already
uses (HuggingFace, Civitai, localhost) can be fetched. Both are required
to keep the server out of the SSRF business for this feature.
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any, Optional
from aiohttp import web
from pydantic import BaseModel, ValidationError
from app.model_downloader.allowlist import is_url_allowed
from app.model_downloader.download_server import (
DOWNLOAD_SERVER,
DownloadSession,
)
from app.model_downloader.downloader import schedule_batch
from app.model_downloader.gated_detection import probe_url
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
from app.model_downloader.hf_auth.eligibility import is_hf_auth_eligible
from app.model_downloader.hf_auth.oauth import (
OAuthInProgressError,
start_login_flow,
)
from app.model_downloader.paths import (
InvalidModelId,
parse_model_id,
resolve_existing,
)
from app.model_downloader.api import schemas_in, schemas_out
ROUTES = web.RouteTableDef()
def register_routes(app: web.Application) -> None:
"""Wire the model-downloader routes into the running aiohttp app.
Called once from ``server.py`` during ``PromptServer`` startup.
"""
app.add_routes(ROUTES)
# ----- response helpers (same envelope as app/assets/api/routes.py) -----
def _error(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
return web.json_response(
{"error": {"code": code, "message": message, "details": details or {}}},
status=status,
)
def _validation_error(code: str, ve: ValidationError) -> web.Response:
return _error(400, code, "Validation failed.", {"errors": json.loads(ve.json())})
def _ok(payload: BaseModel, status: int = 200) -> web.Response:
return web.json_response(
payload.model_dump(mode="json", exclude_none=False),
status=status,
)
async def _parse_body(request: web.Request, model: type[BaseModel]) -> Any:
"""Parse a JSON body into a pydantic model or raise a 400 response."""
try:
raw = await request.json()
except json.JSONDecodeError:
return _error(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
return model.model_validate(raw)
except ValidationError as ve:
return _validation_error("INVALID_BODY", ve)
# ----- 1. availability status (unified: state + metadata per id) -----
@ROUTES.post("/api/models-availability-status")
async def models_availability_status(request: web.Request) -> web.Response:
"""Return per-id ``{state, progress, file_size, is_hf_downloadable}``.
State (``available`` / ``missing`` / ``downloading``) is cheap to
recompute per call. ``file_size`` and ``is_gated`` are cached
server-side per URL. ``is_hf_downloadable`` is recomputed every
call from the current token state that's what makes login + license
acceptance show up in the UI within one poll cycle without any
frontend cache plumbing.
"""
parsed = await _parse_body(request, schemas_in.AvailabilityStatusRequest)
if isinstance(parsed, web.Response):
return parsed
items = list(parsed.models.items())
# Run all probes concurrently; each is internally cached per URL.
probes = await asyncio.gather(*(probe_url(url) for _, url in items))
response_models: dict[str, schemas_out.ModelStatusEntry] = {}
for (model_id, _url), probe in zip(items, probes):
try:
parse_model_id(model_id)
except InvalidModelId:
# Ill-formed identifier: report as missing without 400-ing the
# whole batch — the workflow author probably typo'd.
response_models[model_id] = schemas_out.ModelStatusEntry(
state="missing",
file_size=probe.file_size,
is_hf_downloadable=probe.is_hf_downloadable,
)
continue
active = DOWNLOAD_SERVER.get(model_id)
if active is not None:
response_models[model_id] = schemas_out.ModelStatusEntry(
state="downloading",
progress=schemas_out.DownloadProgress(
bytes_downloaded=active.bytes_downloaded,
total_bytes=active.total_bytes,
progress=active.progress,
),
file_size=probe.file_size,
is_hf_downloadable=probe.is_hf_downloadable,
)
continue
state: schemas_out.ModelState = (
"available" if resolve_existing(model_id) is not None else "missing"
)
response_models[model_id] = schemas_out.ModelStatusEntry(
state=state,
file_size=probe.file_size,
is_hf_downloadable=probe.is_hf_downloadable,
)
return _ok(schemas_out.AvailabilityStatusResponse(
models=response_models,
hf_auth=schemas_out.HfAuthStatus(
token_available=HF_AUTH_STORE.has_token(),
eligible=is_hf_auth_eligible(),
),
))
# ----- 2. start downloads -----
@ROUTES.post("/api/download-models")
async def download_models(request: web.Request) -> web.Response:
parsed = await _parse_body(request, schemas_in.DownloadModelsRequest)
if isinstance(parsed, web.Response):
return parsed
if not parsed.models:
return _error(400, "EMPTY_REQUEST", "No models supplied.")
# ----- precondition pass: validate everything BEFORE registering anything -----
# Atomic semantics: if any model fails any precondition (invalid id,
# not allow-listed URL, already on disk, already downloading, or gated),
# the entire request fails and no state is changed.
requested = list(parsed.models.items())
for model_id, url in requested:
try:
parse_model_id(model_id)
except InvalidModelId as e:
return _error(400, "INVALID_MODEL_ID", str(e),
{"model_id": model_id})
if not is_url_allowed(url):
return _error(
400, "URL_NOT_ALLOWED",
"Server-side downloads only accept HuggingFace, Civitai, "
"or localhost URLs ending in a known model extension.",
{"model_id": model_id, "url": url},
)
if resolve_existing(model_id) is not None:
return _error(409, "ALREADY_AVAILABLE",
f"Model already exists on disk: {model_id}",
{"model_id": model_id})
if DOWNLOAD_SERVER.is_downloading(model_id):
return _error(409, "ALREADY_DOWNLOADING",
f"A download for {model_id} is already in progress.",
{"model_id": model_id})
# Reachability check last — it's the only one that talks to the
# network. Concurrent probes. For HF URLs ``is_hf_downloadable``
# reflects current token access; for non-HF URLs it's None, and we
# treat that as "no info, proceed".
probes = await asyncio.gather(*(probe_url(url) for _, url in requested))
for (model_id, url), probe in zip(requested, probes):
if probe.is_hf_downloadable is False:
return _error(
400, "MODEL_NOT_DOWNLOADABLE",
f"Model {model_id} is gated on HuggingFace and the current "
f"server token (if any) does not grant access.",
{"model_id": model_id, "url": url},
)
# ----- registration pass: try_register is atomic per model_id -----
# Defensive: another request might have raced past our pre-check
# between the loop above and here. try_register handles that.
sessions: list[DownloadSession] = []
for model_id, url in requested:
session = DOWNLOAD_SERVER.try_register(model_id, url)
if session is None:
# Race: someone else got in. Roll back what we registered.
for s in sessions:
DOWNLOAD_SERVER.cancel(s.model_id)
return _error(409, "ALREADY_DOWNLOADING",
f"A download for {model_id} is already in progress (race).",
{"model_id": model_id})
sessions.append(session)
DOWNLOAD_SERVER.sweep_orphan_tmp_files()
schedule_batch(sessions)
logging.info(
"[model_downloader] scheduled %d downloads: %s",
len(sessions), [s.model_id for s in sessions],
)
return _ok(schemas_out.DownloadModelsResponse(
accepted=True,
scheduled=[s.model_id for s in sessions],
), status=202)
# ----- 3. cancel a session -----
@ROUTES.post("/api/cancel-model-download-session")
async def cancel_model_download_session(request: web.Request) -> web.Response:
parsed = await _parse_body(request, schemas_in.CancelDownloadSessionRequest)
if isinstance(parsed, web.Response):
return parsed
cancelled = DOWNLOAD_SERVER.cancel(parsed.model_id)
if not cancelled:
return _error(404, "NOT_DOWNLOADING",
f"No active download for {parsed.model_id}.",
{"model_id": parsed.model_id})
return _ok(schemas_out.CancelDownloadSessionResponse(cancelled=True))
# ----- 4. HuggingFace OAuth status / login start / logout -----
@ROUTES.get("/api/hf-auth-token-status")
async def hf_auth_token_status(request: web.Request) -> web.Response:
"""Return whether the server holds a usable HF token + its username.
Used by the settings UI and (out-of-band) by the frontend on
login completion. ``token_available`` is true even if the cached
access_token is expired as long as a refresh_token exists, the
user is "logged in" from their perspective.
"""
token_present = HF_AUTH_STORE.has_token()
username: Optional[str] = None
if token_present:
# Resolve the username via whoami. Done in a worker thread because
# huggingface_hub's whoami is synchronous + blocks on a network call.
tok = await HF_AUTH_STORE.get_valid_token()
if tok is not None:
try:
username = await asyncio.to_thread(_whoami_username, tok.access_token)
except Exception as e:
logging.debug("[hf_auth] whoami failed: %s", e)
return _ok(schemas_out.HfAuthTokenStatusResponse(
token_available=token_present,
username=username,
))
def _whoami_username(token: str) -> Optional[str]:
"""Sync helper: ask HF for the user name attached to a token."""
from huggingface_hub import HfApi
info = HfApi().whoami(token=token)
if isinstance(info, dict):
return info.get("name") or info.get("fullname")
return None
@ROUTES.post("/api/hf-auth-login-start")
async def hf_auth_login_start(request: web.Request) -> web.Response:
"""Begin one OAuth attempt: bind the callback port, return the URL.
Rejected outright if this deployment isn't eligible (we don't
surface the option on multi-tenant / public-IP installs).
"""
if not is_hf_auth_eligible():
return _error(
403, "HF_AUTH_NOT_ELIGIBLE",
"This server is not eligible for interactive HuggingFace login. "
"It must be bound to a loopback address and not running in "
"--multi-user mode.",
)
try:
url = await start_login_flow()
except OAuthInProgressError:
return _error(
409, "HF_AUTH_IN_PROGRESS",
"Another HuggingFace login attempt is in progress. Try again "
"after it completes or times out.",
)
return _ok(schemas_out.HfAuthLoginStartResponse(authorize_url=url))
@ROUTES.post("/api/hf-auth-logout")
async def hf_auth_logout(request: web.Request) -> web.Response:
"""Drop the in-memory + on-disk HF token."""
HF_AUTH_STORE.clear()
return _ok(schemas_out.HfAuthLogoutResponse(logged_out=True))

View File

@ -0,0 +1,41 @@
"""Request schemas for the model-downloader API.
Each endpoint accepts a small JSON body. Pydantic enforces the shape at
the boundary; route handlers operate only on validated values past that.
"""
from __future__ import annotations
from pydantic import BaseModel, Field
class AvailabilityStatusRequest(BaseModel):
"""``POST /api/models-availability-status``.
Sent by the frontend on each poll. Each entry is ``{model_id: url}``;
the URL is the one declared in ``properties.models[i].url`` in the
workflow JSON and lets the server compute per-id metadata
(``file_size`` + ``is_hf_downloadable``) on the same request.
"""
models: dict[str, str] = Field(default_factory=dict)
class DownloadModelsRequest(BaseModel):
"""``POST /api/download-models``.
Same shape as the metadata request the URL for each model_id.
Returns immediately after validation and scheduling.
"""
models: dict[str, str] = Field(default_factory=dict)
class CancelDownloadSessionRequest(BaseModel):
"""``POST /api/cancel-model-download-session``."""
model_id: str
__all__ = [
"AvailabilityStatusRequest",
"DownloadModelsRequest",
"CancelDownloadSessionRequest",
]

View File

@ -0,0 +1,81 @@
"""Response schemas for the model-downloader API."""
from __future__ import annotations
from typing import Literal, Optional
from pydantic import BaseModel
ModelState = Literal["available", "missing", "downloading"]
class DownloadProgress(BaseModel):
"""Embedded in a model entry when its state is ``downloading``."""
bytes_downloaded: int
total_bytes: Optional[int] = None
progress: Optional[float] = None # fraction in [0,1]; null until total known
class ModelStatusEntry(BaseModel):
"""Everything the UI needs to render one row, in one shot.
``state`` reflects what the server has on disk + in-flight; ``file_size``
and ``is_hf_downloadable`` come from probes (intrinsic; cached).
The HF fields are populated for every poll (cached on the server),
so license-acceptance flips show up within one poll interval without
any frontend cache invalidation.
"""
state: ModelState
progress: Optional[DownloadProgress] = None
file_size: Optional[int] = None
# HF-only: True iff the server can fetch this URL with current auth
# state. False iff gated and lacking access. None for non-HF URLs.
is_hf_downloadable: Optional[bool] = None
class HfAuthStatus(BaseModel):
"""Snapshot of HF login state, embedded in availability response."""
token_available: bool
eligible: bool
class AvailabilityStatusResponse(BaseModel):
models: dict[str, ModelStatusEntry]
hf_auth: HfAuthStatus
class DownloadModelsResponse(BaseModel):
accepted: bool
scheduled: list[str]
class CancelDownloadSessionResponse(BaseModel):
cancelled: bool
class HfAuthTokenStatusResponse(BaseModel):
token_available: bool
username: Optional[str] = None
class HfAuthLoginStartResponse(BaseModel):
authorize_url: str
class HfAuthLogoutResponse(BaseModel):
logged_out: bool
__all__ = [
"ModelState",
"DownloadProgress",
"ModelStatusEntry",
"HfAuthStatus",
"AvailabilityStatusResponse",
"DownloadModelsResponse",
"CancelDownloadSessionResponse",
"HfAuthTokenStatusResponse",
"HfAuthLoginStartResponse",
"HfAuthLogoutResponse",
]

View File

@ -0,0 +1,179 @@
"""Process-wide registry of in-flight model downloads.
A single instance, ``DOWNLOAD_SERVER``, tracks every currently-running
server-side model fetch. Designed to be safe with multiple concurrent
clients hitting the API: each model_id has at most one active session,
and the API rejects requests that conflict with in-flight downloads.
Cancellation is cooperative the download loop checks ``is_active`` on
its own session between chunks and raises ``DownloadCancelled`` when the
session has been removed from the registry. This avoids the complications
of ``Task.cancel()`` from outside the loop while still giving deterministic
rollback semantics (the worker is responsible for deleting its own
``.tmp`` on the cancel path).
"""
from __future__ import annotations
import logging
import os
import threading
from dataclasses import dataclass, field
from typing import Optional
from app.model_downloader.paths import iter_all_tmp_paths
class DownloadCancelled(Exception):
"""Raised by the streaming loop when its session has been removed
from the registry (cancellation request) and the worker should roll
back its ``.tmp`` file."""
@dataclass
class DownloadSession:
"""One in-flight download.
``progress`` is a fraction in ``[0.0, 1.0]``; ``None`` until the first
byte arrives and we know whether the response carries a
``Content-Length``. ``total_bytes`` mirrors that header when present.
"""
model_id: str
url: str
progress: Optional[float] = None
bytes_downloaded: int = 0
total_bytes: Optional[int] = None
# Sequence number used solely as identity for the cancellation check —
# so that "cancel + restart" doesn't get confused by stale workers.
epoch: int = field(default_factory=lambda: 0)
class DownloadServer:
"""Singleton registry of active downloads.
All mutation goes through this object so concurrent route handlers
see a consistent view. The ``_lock`` is a plain threading lock
because the registry is consulted from both the asyncio event-loop
thread (route handlers) and from any worker coroutines spawned to
perform downloads.
"""
def __init__(self) -> None:
self._lock = threading.Lock()
self._sessions: dict[str, DownloadSession] = {}
self._epoch_counter = 0
self._orphan_sweep_done = False
# ----- lifecycle -----
def sweep_orphan_tmp_files(self) -> None:
"""Idempotently sweep ``*.tmp`` files left by crashed downloads.
Deferred off the import path so module load doesn't block on
filesystem I/O against potentially-slow mounts. Each route handler
that might create a new ``.tmp`` runs this exactly once.
"""
with self._lock:
if self._orphan_sweep_done:
return
self._orphan_sweep_done = True
for path in iter_all_tmp_paths():
try:
os.remove(path)
logging.info("[model_downloader] removed orphan tmp file: %s", path)
except OSError as e:
logging.warning("[model_downloader] could not remove %s: %s", path, e)
# ----- queries -----
def is_downloading(self, model_id: str) -> bool:
with self._lock:
return model_id in self._sessions
def get(self, model_id: str) -> Optional[DownloadSession]:
with self._lock:
return self._sessions.get(model_id)
def snapshot(self) -> dict[str, DownloadSession]:
"""Return a shallow copy of the current sessions map."""
with self._lock:
return dict(self._sessions)
# ----- mutations -----
def try_register(self, model_id: str, url: str) -> Optional[DownloadSession]:
"""Atomically register a new session iff none exists for ``model_id``.
Returns the new session on success, ``None`` if a session is already
in flight. Callers must check the return value the caller is the
sole owner of the session it gets back.
"""
with self._lock:
if model_id in self._sessions:
return None
self._epoch_counter += 1
session = DownloadSession(
model_id=model_id,
url=url,
epoch=self._epoch_counter,
)
self._sessions[model_id] = session
return session
def update_progress(
self,
session: DownloadSession,
bytes_downloaded: int,
total_bytes: Optional[int],
) -> None:
"""Update progress on a session. No-op if the session has been
removed (cancelled) caller should check ``is_active`` separately."""
with self._lock:
current = self._sessions.get(session.model_id)
if current is None or current.epoch != session.epoch:
return
current.bytes_downloaded = bytes_downloaded
current.total_bytes = total_bytes
if total_bytes and total_bytes > 0:
current.progress = min(1.0, bytes_downloaded / total_bytes)
def is_active(self, session: DownloadSession) -> bool:
"""True iff this exact session is still the registered one for
its model_id. False after cancellation, after completion, or if
another session has replaced it."""
with self._lock:
current = self._sessions.get(session.model_id)
return current is not None and current.epoch == session.epoch
def finish(self, session: DownloadSession) -> None:
"""Remove a completed (or cancelled) session from the registry.
Safe to call multiple times. Only removes if the epoch matches
we never accidentally evict a *newer* session for the same model_id.
"""
with self._lock:
current = self._sessions.get(session.model_id)
if current is not None and current.epoch == session.epoch:
del self._sessions[session.model_id]
def reset_for_tests(self) -> None:
"""Clear all sessions and reset the epoch counter. Test-only."""
with self._lock:
self._sessions.clear()
self._epoch_counter = 0
def cancel(self, model_id: str) -> bool:
"""Remove the session registered for ``model_id``.
Returns True if there was an active session to cancel. The worker
will discover the cancellation on its next ``is_active`` check
and roll back its ``.tmp`` file.
"""
with self._lock:
if model_id in self._sessions:
del self._sessions[model_id]
return True
return False
DOWNLOAD_SERVER = DownloadServer()

View File

@ -0,0 +1,205 @@
"""Streaming download worker with progress reporting and cancellation.
Each download writes to ``<final_path>.tmp`` and atomically renames into
place on success. Between chunks the worker checks the registry for
cancellation (via ``DownloadServer.is_active``) and rolls back its
``.tmp`` on cancel or on any error.
"""
from __future__ import annotations
import asyncio
import logging
import os
from typing import Optional
import aiohttp
from app.model_downloader.download_server import (
DOWNLOAD_SERVER,
DownloadCancelled,
DownloadSession,
)
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
from app.model_downloader.hf_url import is_hf_url
from app.model_downloader.http_client import get_session, parse_content_length
from app.model_downloader.paths import resolve_destination
CHUNK_SIZE = 64 * 1024 # 64 KiB — same scale as other ComfyUI download paths.
REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=120)
async def stream_to_disk(session: DownloadSession) -> str:
"""Run a single download to completion or cancellation.
Returns the final on-disk path on success. Removes the ``.tmp`` and
raises on cancellation or failure. The session is finished
(removed from the registry) exactly once, here callers do not
need to call ``DOWNLOAD_SERVER.finish`` themselves.
"""
final_path, tmp_path = resolve_destination(session.model_id)
os.makedirs(os.path.dirname(final_path), exist_ok=True)
# Wipe any stale .tmp from a previous failed attempt before we start —
# otherwise a partial body could masquerade as our completed download
# when the rename finally happens.
_remove_if_exists(tmp_path)
bytes_seen = 0
try:
http = await get_session()
headers = _auth_headers_for(session.url)
logging.info(
"[model_downloader] starting GET %s (auth=%s)",
session.url, "yes" if "Authorization" in headers else "no",
)
async with http.get(
session.url,
allow_redirects=True,
timeout=REQUEST_TIMEOUT,
headers=headers,
) as resp:
if resp.status != 200:
# Capture a snippet of the response body so 4xx/5xx aren't
# opaque in the logs — HF returns JSON or HTML with a
# human-readable reason on failures.
body_snippet = await _read_short(resp)
logging.warning(
"[model_downloader] GET %s failed: status=%d final_url=%s body=%s",
session.url, resp.status, str(resp.url), body_snippet,
)
raise DownloadError(
f"unexpected HTTP {resp.status} fetching {session.url}: {body_snippet}",
status=resp.status,
)
total = parse_content_length(resp.headers.get("Content-Length"))
DOWNLOAD_SERVER.update_progress(session, 0, total)
with open(tmp_path, "wb") as f:
async for chunk in resp.content.iter_chunked(CHUNK_SIZE):
# Cancellation check between chunks. Cheap and means
# cancellation latency is bounded by one chunk plus
# one ``write()`` — typically well under a second
# even on slow disks.
if not DOWNLOAD_SERVER.is_active(session):
raise DownloadCancelled()
f.write(chunk)
bytes_seen += len(chunk)
DOWNLOAD_SERVER.update_progress(session, bytes_seen, total)
# Final cancellation check before we promote the .tmp to the real
# filename — avoids the awkward case where cancel arrives during
# the very last chunk and we'd otherwise commit anyway.
if not DOWNLOAD_SERVER.is_active(session):
raise DownloadCancelled()
# Atomic rename. os.replace is atomic within the same filesystem,
# which is guaranteed here because tmp lives alongside final_path.
os.replace(tmp_path, final_path)
logging.info(
"[model_downloader] downloaded %s (%d bytes) from %s",
session.model_id, bytes_seen, session.url,
)
return final_path
except DownloadCancelled:
logging.info("[model_downloader] cancelled: %s", session.model_id)
_remove_if_exists(tmp_path)
raise
except Exception as e:
logging.warning(
"[model_downloader] failed: %s from %s: %s: %s",
session.model_id, session.url, type(e).__name__, e,
exc_info=True,
)
_remove_if_exists(tmp_path)
raise
finally:
# In all terminal states (success / cancel / error) drop the
# session from the registry. Idempotent — only removes if we're
# still the live epoch for this model_id.
DOWNLOAD_SERVER.finish(session)
class DownloadError(Exception):
"""Network / protocol error during a download."""
def __init__(self, message: str, status: Optional[int] = None) -> None:
super().__init__(message)
self.status = status
async def _read_short(resp: aiohttp.ClientResponse, limit: int = 512) -> str:
"""Read up to ``limit`` bytes of a response body for logging.
Used to surface the JSON/HTML reason from an HF non-2xx response in
server logs instead of just the status code. Best-effort: any
error here is swallowed.
"""
try:
raw = await resp.content.read(limit)
return raw.decode("utf-8", errors="replace").strip()
except Exception:
return "<unreadable>"
def _auth_headers_for(url: str) -> dict[str, str]:
"""Return any auth headers we should add to the GET for ``url``.
For HuggingFace URLs we inject the user's OAuth access token as a
Bearer header this is HF's documented way to access gated repos
(see ``huggingface_hub.hf_hub_download``'s wire format). For every
other host we send no extra headers; allowlisted public files
don't need them and we don't want to leak tokens to other hosts.
"""
if not is_hf_url(url):
return {}
tok = HF_AUTH_STORE.get_token_sync()
if tok is None or not tok.access_token:
return {}
return {"Authorization": f"Bearer {tok.access_token}"}
def _remove_if_exists(path: str) -> None:
try:
os.remove(path)
except FileNotFoundError:
pass
except OSError as e:
logging.warning("[model_downloader] could not remove %s: %s", path, e)
async def run_batch_sequential(sessions: list[DownloadSession]) -> None:
"""Run a list of sessions one after the other.
Each session is independent: a failure or cancellation of one does
not abort the rest. Cancellations are observable via the registry
*before* a given download starts, so a session that's been
pre-cancelled (cancel before the worker reached it) just gets skipped.
"""
for session in sessions:
# If the session got cancelled before its turn, skip without
# touching disk. This is what makes the per-request "sequential
# but cancellable" semantic work.
if not DOWNLOAD_SERVER.is_active(session):
DOWNLOAD_SERVER.finish(session)
continue
try:
await stream_to_disk(session)
except DownloadCancelled:
# Already logged + tmp removed inside stream_to_disk.
continue
except Exception:
# stream_to_disk already logged. Continue with the rest of the batch.
continue
def schedule_batch(sessions: list[DownloadSession]) -> asyncio.Task:
"""Kick off ``run_batch_sequential`` on the running event loop.
Returned task is fire-and-forget; the API handler returns immediately
after scheduling and clients observe progress via the polling endpoints.
"""
return asyncio.create_task(run_batch_sequential(sessions))

View File

@ -0,0 +1,238 @@
"""Per-URL probes for the unified availability endpoint.
Three cached/derived facts per URL:
- ``is_gated`` intrinsic to the model; cached forever once known.
Determined by ``auth_check(repo_id, token=None)``:
``GatedRepoError`` True, success False.
- ``is_hf_downloadable`` depends on the *current* token; recomputed every
call. For non-gated URLs this is trivially True
(no HF call needed). For gated URLs we run
``auth_check`` with the stored token each call.
- ``file_size`` intrinsic to the file. Cached forever once
determined (including ``None`` on transient
failure we don't retry). We only attempt the
HEAD when we already know the URL is downloadable
to us; that way a failed-because-gated probe
never lands as a cached ``None``.
Caches are per-process, in-memory; small, no eviction needed for the
workflow-scale (~tens of URLs). Concurrent calls for the same URL
deduplicate via per-URL ``asyncio.Lock``.
"""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from typing import Optional
import aiohttp
from huggingface_hub import HfApi
from huggingface_hub.errors import (
GatedRepoError,
HfHubHTTPError,
RepositoryNotFoundError,
)
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
from app.model_downloader.hf_url import is_hf_url, repo_id_from_url
from app.model_downloader.http_client import get_session, parse_content_length
_HEAD_TIMEOUT = aiohttp.ClientTimeout(total=15)
@dataclass
class ProbeResult:
file_size: Optional[int]
is_hf_downloadable: Optional[bool]
# --- caches -------------------------------------------------------------- #
# url → bool. Whether this URL's HF repo gates access. Intrinsic to the
# model — never changes for a given URL.
_is_gated_cache: dict[str, bool] = {}
# url → Optional[int]. The file's size in bytes, ``None`` if a probe
# was attempted and produced no answer. **Only populated when we knew
# the URL was downloadable to us at probe time** — so gated-without-
# access never lands a ``None`` here that we'd be stuck with after login.
_file_size_cache: dict[str, Optional[int]] = {}
# Per-URL locks for single-flight probes — when multiple polls arrive
# in the same tick for the same URL, exactly one of them runs the HF
# call and the others wait on the result.
_locks: dict[str, asyncio.Lock] = {}
def _lock_for(url: str) -> asyncio.Lock:
lock = _locks.get(url)
if lock is None:
lock = asyncio.Lock()
_locks[url] = lock
return lock
def clear_caches_for_tests() -> None:
"""Test-only: drop everything."""
_is_gated_cache.clear()
_file_size_cache.clear()
_locks.clear()
# --- public entrypoint --------------------------------------------------- #
async def probe_url(url: str) -> ProbeResult:
"""Return downloadability + size for one URL, using caches where safe."""
if not is_hf_url(url):
# Non-HF: ``is_hf_downloadable`` is "not applicable" (None).
# Size we still cache so we don't HEAD on every poll.
size = await _get_or_probe_size(url, token=None)
return ProbeResult(file_size=size, is_hf_downloadable=None)
repo_id = repo_id_from_url(url)
if repo_id is None:
return ProbeResult(file_size=None, is_hf_downloadable=None)
# Determine intrinsic gating once.
gated = await _resolve_is_gated(url, repo_id)
if gated is None:
return ProbeResult(file_size=None, is_hf_downloadable=None)
# Compute current-token downloadability per call.
tok = HF_AUTH_STORE.get_token_sync()
token_str: Optional[str] = tok.access_token if tok else None
if not gated:
is_hf_downloadable: Optional[bool] = True
else:
is_hf_downloadable = await _auth_check_with_token(repo_id, token_str)
if is_hf_downloadable is True:
size = await _get_or_probe_size(url, token=token_str)
else:
# Skip the HEAD entirely — would 401 and we'd be stuck with
# cached None that survives a later login.
size = None
return ProbeResult(file_size=size, is_hf_downloadable=is_hf_downloadable)
# --- gated/auth probes --------------------------------------------------- #
async def _resolve_is_gated(url: str, repo_id: str) -> Optional[bool]:
"""Decide once whether ``repo_id`` is a gated repo."""
cached = _is_gated_cache.get(url)
if cached is not None:
return cached
async with _lock_for(url):
cached = _is_gated_cache.get(url)
if cached is not None:
return cached
try:
await asyncio.to_thread(_auth_check_sync, repo_id, None)
_is_gated_cache[url] = False
return False
except GatedRepoError:
_is_gated_cache[url] = True
return True
except RepositoryNotFoundError:
# Repo doesn't exist publicly. Treat as gated — we can't
# serve it without auth, and an authenticated check might
# still succeed if it's a private repo the user can see.
_is_gated_cache[url] = True
return True
except (HfHubHTTPError, Exception) as e:
logging.debug(
"[hf_auth] is_gated probe failed for %s (will retry): %s",
repo_id, e,
)
return None # don't cache; retry next call
async def _auth_check_with_token(
repo_id: str, token: Optional[str]
) -> Optional[bool]:
"""Auth-check with the supplied token. True/False/None per outcome."""
try:
await asyncio.to_thread(_auth_check_sync, repo_id, token)
return True
except GatedRepoError:
return False
except RepositoryNotFoundError:
return False
except HfHubHTTPError as e:
# 401/403 covers org-SSO-required, revoked tokens, and similar —
# all of which mean "can't fetch right now" from the user's POV.
status = getattr(getattr(e, "response", None), "status_code", None)
if status in (401, 403):
return False
logging.debug(
"[hf_auth] auth_check transient failure for %s: %s", repo_id, e,
)
return None
except Exception as e:
logging.warning("[hf_auth] unexpected auth_check error for %s: %s", repo_id, e)
return None
def _auth_check_sync(repo_id: str, token: Optional[str]) -> None:
"""Thin sync wrapper around ``HfApi.auth_check`` for ``asyncio.to_thread``."""
HfApi().auth_check(repo_id, token=token)
# --- size probe ---------------------------------------------------------- #
async def _get_or_probe_size(url: str, token: Optional[str]) -> Optional[int]:
"""Return the cached size or HEAD the URL once and cache the result."""
if url in _file_size_cache:
return _file_size_cache[url]
async with _lock_for(url):
if url in _file_size_cache:
return _file_size_cache[url]
size = await _probe_size_once(url, token=token)
_file_size_cache[url] = size
return size
async def _probe_size_once(url: str, token: Optional[str]) -> Optional[int]:
"""HEAD the URL and return the file size in bytes, or None on failure.
HuggingFace serves LFS-tracked files via a 302 to a signed CDN URL.
The real file size lives in the non-standard ``X-Linked-Size`` header
on that 302 response (``Content-Length`` is the redirect-body length).
Disabling redirect-follow lets us read either header on the same
response:
- LFS files: 302 + ``X-Linked-Size``
- Small/non-LFS files: 200 + ``Content-Length``
"""
headers = {"Authorization": f"Bearer {token}"} if token else {}
try:
session = await get_session()
async with session.head(
url, allow_redirects=False, timeout=_HEAD_TIMEOUT, headers=headers,
) as resp:
linked = parse_content_length(resp.headers.get("X-Linked-Size"))
if linked is not None:
return linked
if resp.status == 200:
return parse_content_length(resp.headers.get("Content-Length"))
return None
except (aiohttp.ClientError, TimeoutError, OSError):
return None
# Backward-compat shim so consumers that still import the old name keep
# building during the refactor; can be removed once routes are updated.
MetadataProbeResult = ProbeResult

View File

@ -0,0 +1,106 @@
"""In-memory token cache with lazy disk persistence + refresh.
Public surface is the ``HF_AUTH_STORE`` singleton. Callers ask
``get_valid_token()``; the store transparently refreshes from disk
on first use, refreshes via the OAuth refresh_token if the cached
access_token is expired, and returns ``None`` if neither path works.
The refresh path imports ``oauth.refresh_access_token`` lazily to
avoid an import cycle (oauth needs the store to save tokens it
acquires).
"""
from __future__ import annotations
import asyncio
import logging
import threading
from typing import Optional
from app.model_downloader.hf_auth.token_store import (
Token,
delete_token,
load_token,
save_token,
)
class HfAuthStore:
def __init__(self) -> None:
self._lock = threading.Lock()
self._token: Optional[Token] = None
self._loaded_from_disk = False
def _ensure_loaded(self) -> None:
"""Read the disk token into memory on first access."""
if self._loaded_from_disk:
return
with self._lock:
if self._loaded_from_disk:
return
self._token = load_token()
self._loaded_from_disk = True
def has_token(self) -> bool:
"""Cheap check: is there any token in memory?
Does not attempt refresh; an expired-but-refreshable token still
counts as "logged in" from the user's perspective.
"""
self._ensure_loaded()
return self._token is not None
def set_token(self, token: Token) -> None:
"""Replace the in-memory token and persist to disk."""
with self._lock:
self._token = token
self._loaded_from_disk = True
save_token(token)
def clear(self) -> None:
"""Forget the token in memory and on disk (logout)."""
with self._lock:
self._token = None
self._loaded_from_disk = True
delete_token()
def get_token_sync(self) -> Optional[Token]:
"""Return the cached token without refreshing.
Sync callers (e.g. constructing an Authorization header in a
non-async path) use this. They accept an expired token over
``None``; HF will simply return 401 and the caller can decide
what to do.
"""
self._ensure_loaded()
return self._token
async def get_valid_token(self) -> Optional[Token]:
"""Return a fresh token, refreshing via OAuth if necessary.
Returns ``None`` if there's no cached token at all, or if the
cached token is expired and refresh failed. Callers should
treat that as "user is not logged in".
"""
self._ensure_loaded()
tok = self._token
if tok is None:
return None
if tok.is_valid():
return tok
if not tok.refresh_token:
return None
# Lazy import to avoid the oauth ↔ store import cycle.
from app.model_downloader.hf_auth.oauth import refresh_access_token
try:
refreshed = await refresh_access_token(tok.refresh_token)
except Exception as e:
logging.warning("[hf_auth] token refresh failed: %s", e)
return None
self.set_token(refreshed)
return refreshed
HF_AUTH_STORE = HfAuthStore()

View File

@ -0,0 +1,55 @@
"""Whether this deployment is allowed to do interactive HF OAuth.
We only let the server hold a HuggingFace token under a strict trust
assumption: this is a *single tenant local* install. Concretely:
- The server is bound to a loopback address. SSH tunneling /
reverse-proxies can defeat this, but it's the strongest signal
we have without an authentication system.
- ``--multi-user`` is off. A shared token used implicitly by multiple
declared users would be a footgun one user's gated downloads
would silently authenticate as another.
Anything else and the frontend hides the HF login UI entirely; gated
models continue to show the "acquire it manually" message.
"""
from __future__ import annotations
import ipaddress
import socket
from comfy.cli_args import args
def _is_loopback(host: str | None) -> bool:
"""Duplicates ``server.is_loopback`` (small, no shared module yet).
Resolves a host or IP literal to whether it lives on the loopback
interface (127.0.0.0/8 for IPv4, ::1 for IPv6). Returns False for
``0.0.0.0`` / ``::`` because those are bind-all wildcards, not
loopback.
"""
if host is None:
return False
try:
return ipaddress.ip_address(host).is_loopback
except ValueError:
pass
loopback = False
for family in (socket.AF_INET, socket.AF_INET6):
try:
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
for _family, _, _, _, sockaddr in r:
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
return loopback
loopback = True
except socket.gaierror:
pass
return loopback
def is_hf_auth_eligible() -> bool:
"""True iff this deployment may surface the HF OAuth flow."""
return _is_loopback(args.listen) and not args.multi_user

View File

@ -0,0 +1,277 @@
"""OAuth 2.0 PKCE flow against HuggingFace's authorization server.
Wired so that ``POST /api/hf-auth-login-start`` can:
1. Generate state + PKCE verifier/challenge in this process.
2. Spin up a short-lived loopback HTTP server at port 41954 to
receive the redirect callback from HF.
3. Return the ``authorize_url`` for the frontend to open in a new tab.
After the user grants consent on huggingface.co, HF redirects to the
local callback URL with ``code`` and ``state``. The callback server
validates ``state`` (CSRF), exchanges the code for tokens via PKCE,
hands the resulting Token to ``HF_AUTH_STORE.set_token``, and shuts
itself down.
Before this can be exercised end-to-end a maintainer must register a
HuggingFace OAuth app and substitute the ``HF_CLIENT_ID`` placeholder
below. See the comment above the constant for the exact steps.
"""
from __future__ import annotations
import asyncio
import base64
import hashlib
import logging
import secrets
import threading
import time
import aiohttp
from aiohttp import web
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
from app.model_downloader.hf_auth.token_store import Token
from app.model_downloader.http_client import ssl_context
# --- HF OAuth app registration -------------------------------------------- #
# NOTE: The OAuth client_id below is a placeholder. Before this feature can be
# exercised end-to-end, a maintainer must register a HuggingFace OAuth app
# under a Comfy-Org-controlled HF account and substitute its client_id here.
# Detailed walkthrough is in docs/server-side-model-downloads-handover.html
# ("HuggingFace OAuth app setup" section). Short version:
# 1. huggingface.co → Settings → Connected Apps → "Create app"
# 2. Default Scopes: check ``openid`` + ``profile`` (User Info) and
# ``gated-repos`` (Repository Access). Leave everything else off.
# 3. Redirect URLs: exactly ``http://127.0.0.1:41954/api/auth/huggingface/callback``
# — must match ``REDIRECT_URI`` below; change both in lockstep if you
# change ``CALLBACK_PORT``.
# 4. Save → copy the resulting Client ID into ``HF_CLIENT_ID`` below.
# The client_id is not a secret (it travels through the user's browser in
# plaintext); HF's "Public app" type means there's no client secret to
# manage — PKCE replaces it.
HF_CLIENT_ID = "REPLACE_ME_WITH_COMFY_ORG_HF_OAUTH_CLIENT_ID"
CALLBACK_HOST = "127.0.0.1"
CALLBACK_PORT = 41954
CALLBACK_PATH = "/api/auth/huggingface/callback"
REDIRECT_URI = f"http://{CALLBACK_HOST}:{CALLBACK_PORT}{CALLBACK_PATH}"
AUTHORIZE_URL = "https://huggingface.co/oauth/authorize"
TOKEN_URL = "https://huggingface.co/oauth/token"
# Minimal scope set for the feature:
# - openid : required by HF when the app uses OIDC at all
# - profile : lets ``HfApi.whoami(token=...)`` return a username for the
# settings UI; cosmetic but expected
# - gated-repos : grants the token enough to call ``auth_check`` and
# download files from public gated repos the user has
# accepted the license for. The wider ``read-repos`` scope
# would also work (it includes ``gated-repos``) but it
# additionally grants private-repo read access, which we
# don't need and which makes the consent screen scarier
# for the user.
SCOPE = "openid profile gated-repos"
# Maximum time the callback server stays up waiting for the user to
# complete consent on huggingface.co. Past this, the port closes and
# the user has to click "Log in" again.
CALLBACK_TIMEOUT_SECS = 300
# Process-wide lock so two simultaneous /api/hf-auth-login-start
# requests don't fight over port CALLBACK_PORT.
_OAUTH_LOCK = threading.Lock()
class OAuthInProgressError(Exception):
"""Another OAuth attempt is already running."""
class OAuthCallbackError(Exception):
"""The OAuth callback returned an error (HF denied, port stolen, etc.)."""
# --- PKCE primitives ------------------------------------------------------ #
def _make_pkce() -> tuple[str, str, str]:
"""Return ``(verifier, challenge, state)``.
Verifier never leaves this process. Challenge and state travel
through the user's browser. State is checked on the callback to
prevent a malicious cross-origin redirect from injecting a token.
"""
verifier = secrets.token_urlsafe(64)
challenge = (
base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("ascii")).digest())
.rstrip(b"=")
.decode("ascii")
)
state = secrets.token_urlsafe(32)
return verifier, challenge, state
def _build_authorize_url(challenge: str, state: str) -> str:
from urllib.parse import urlencode
params = {
"client_id": HF_CLIENT_ID,
"redirect_uri": REDIRECT_URI,
"response_type": "code",
"scope": SCOPE,
"state": state,
"code_challenge": challenge,
"code_challenge_method": "S256",
}
return f"{AUTHORIZE_URL}?{urlencode(params)}"
# --- Token exchange ------------------------------------------------------- #
async def _exchange_code(code: str, verifier: str) -> Token:
"""Trade the authorization code for an access+refresh token pair."""
data = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": REDIRECT_URI,
"client_id": HF_CLIENT_ID,
"code_verifier": verifier,
}
timeout = aiohttp.ClientTimeout(total=30)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(TOKEN_URL, data=data, ssl=ssl_context()) as resp:
resp.raise_for_status()
body = await resp.json()
return Token(
access_token=body["access_token"],
refresh_token=body.get("refresh_token"),
expires_at=time.time() + float(body.get("expires_in", 3600)),
scope=body.get("scope", SCOPE),
)
async def refresh_access_token(refresh_token: str) -> Token:
"""Trade a refresh_token for a new access (+ possibly refresh) token."""
data = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": HF_CLIENT_ID,
}
timeout = aiohttp.ClientTimeout(total=30)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(TOKEN_URL, data=data, ssl=ssl_context()) as resp:
resp.raise_for_status()
body = await resp.json()
return Token(
access_token=body["access_token"],
# If HF doesn't rotate refresh tokens, keep using the existing one.
refresh_token=body.get("refresh_token", refresh_token),
expires_at=time.time() + float(body.get("expires_in", 3600)),
scope=body.get("scope", SCOPE),
)
# --- Callback server ------------------------------------------------------ #
async def start_login_flow() -> str:
"""Begin one OAuth attempt: spawn the callback server, return the URL.
Returns the URL the frontend should open in a new tab. Raises
``OAuthInProgressError`` if another attempt is already running.
The callback server runs in the background until the user
completes consent or until ``CALLBACK_TIMEOUT_SECS`` elapses;
either way the lock + port are released afterward.
"""
if not _OAUTH_LOCK.acquire(blocking=False):
raise OAuthInProgressError()
verifier, challenge, state = _make_pkce()
authorize_url = _build_authorize_url(challenge, state)
# Fire the callback server on the running loop and return.
asyncio.create_task(_run_callback_server(verifier, state))
return authorize_url
async def _run_callback_server(verifier: str, expected_state: str) -> None:
"""Listen for HF's redirect once, capture the token, then shut down."""
received: asyncio.Future[Token] = asyncio.get_event_loop().create_future()
async def handler(request: web.Request) -> web.Response:
try:
if request.query.get("state") != expected_state:
return web.Response(status=400, text="state mismatch")
err = request.query.get("error")
if err:
received.set_exception(OAuthCallbackError(f"HF returned: {err}"))
return web.Response(status=400, text=f"OAuth error: {err}")
code = request.query.get("code")
if not code:
return web.Response(status=400, text="missing code")
tok = await _exchange_code(code, verifier)
if not received.done():
received.set_result(tok)
return web.Response(
content_type="text/html",
text=(
"<html><body style='font-family:sans-serif;padding:40px'>"
"<h2>HuggingFace login successful</h2>"
"<p>You can close this tab and return to ComfyUI.</p>"
"</body></html>"
),
)
except Exception as exc:
if not received.done():
received.set_exception(exc)
return web.Response(status=500, text=str(exc))
app = web.Application()
app.router.add_get(CALLBACK_PATH, handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, CALLBACK_HOST, CALLBACK_PORT, reuse_address=True)
try:
await site.start()
except OSError as e:
# Port already in use (or some other socket-bind failure). Release
# the lock so a future attempt has a chance to succeed.
logging.warning("[hf_auth] could not bind callback port: %s", e)
_OAUTH_LOCK.release()
return
try:
token = await asyncio.wait_for(received, timeout=CALLBACK_TIMEOUT_SECS)
except asyncio.TimeoutError:
logging.info("[hf_auth] OAuth login timed out after %ds", CALLBACK_TIMEOUT_SECS)
return
except OAuthCallbackError as e:
logging.warning("[hf_auth] OAuth callback error: %s", e)
return
except Exception as e:
logging.warning("[hf_auth] unexpected OAuth failure: %s", e)
return
else:
HF_AUTH_STORE.set_token(token)
logging.info("[hf_auth] OAuth login complete")
finally:
await runner.cleanup()
if _OAUTH_LOCK.locked():
_OAUTH_LOCK.release()
def is_login_in_progress() -> bool:
"""True iff a callback server is currently bound + waiting."""
return _OAUTH_LOCK.locked()
# Re-export for callers that only want the URL builder (e.g. tests).
__all__ = [
"start_login_flow",
"refresh_access_token",
"is_login_in_progress",
"OAuthInProgressError",
"CALLBACK_TIMEOUT_SECS",
]

View File

@ -0,0 +1,89 @@
"""On-disk persistence for the HuggingFace OAuth token.
The token shape mirrors what HF returns on the token exchange: an
``access_token``, an optional ``refresh_token``, the absolute epoch at
which the access token expires, and the granted scope. We persist
this so logging in once survives ComfyUI restarts; the file is mode
``0600`` so only the OS user can read it.
"""
from __future__ import annotations
import json
import logging
import os
import stat
import time
from dataclasses import asdict, dataclass
from typing import Optional
import folder_paths
# Treat a token as expired this many seconds before its server-reported
# ``expires_at`` so we don't try to use a token mid-request only for it
# to flip stale between auth_check and the actual GET.
EXPIRY_BUFFER_SECS = 60
TOKEN_FILENAME = "hf_auth_token.json"
@dataclass
class Token:
"""One OAuth token + the metadata we need to use it."""
access_token: str
refresh_token: Optional[str]
expires_at: float # absolute epoch seconds
scope: str = ""
def is_valid(self) -> bool:
"""True iff we can use this token right now."""
return (
bool(self.access_token)
and (self.expires_at - time.time() > EXPIRY_BUFFER_SECS)
)
def _token_path() -> str:
base = folder_paths.get_user_directory()
return os.path.join(base, TOKEN_FILENAME)
def load_token() -> Optional[Token]:
"""Read the persisted token, returning ``None`` if absent or corrupt."""
path = _token_path()
if not os.path.exists(path):
return None
try:
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
return Token(**data)
except (OSError, json.JSONDecodeError, TypeError) as e:
logging.warning("[hf_auth] could not load token at %s: %s", path, e)
return None
def save_token(token: Token) -> None:
"""Atomically write the token with 0600 permissions."""
path = _token_path()
os.makedirs(os.path.dirname(path), exist_ok=True)
tmp = path + ".tmp"
with open(tmp, "w", encoding="utf-8") as f:
json.dump(asdict(token), f)
os.replace(tmp, path)
try:
os.chmod(path, stat.S_IRUSR | stat.S_IWUSR)
except OSError as e:
# On Windows / weird filesystems chmod may be a no-op; not fatal.
logging.debug("[hf_auth] chmod 0600 on %s failed: %s", path, e)
def delete_token() -> None:
"""Remove the persisted token; no-op if it doesn't exist."""
path = _token_path()
try:
os.remove(path)
except FileNotFoundError:
pass
except OSError as e:
logging.warning("[hf_auth] could not remove token at %s: %s", path, e)

View File

@ -0,0 +1,41 @@
"""Parsers for the ``huggingface.co`` URL shape we accept in workflows.
The download API accepts URLs of the form
``https://huggingface.co/<org>/<repo>/resolve/<rev>/<path/to/file>``.
We need to recover ``<org>/<repo>`` (the *repo_id*) from such URLs for
``huggingface_hub`` API calls (notably ``HfApi.auth_check``).
"""
from __future__ import annotations
from typing import Optional
from urllib.parse import urlparse
_HF_HOST = "huggingface.co"
def is_hf_url(url: str) -> bool:
"""Cheap host check — does this URL point at huggingface.co?"""
try:
return urlparse(url).hostname == _HF_HOST
except ValueError:
return False
def repo_id_from_url(url: str) -> Optional[str]:
"""Extract ``<org>/<repo>`` from an HF model file URL.
Returns ``None`` if the URL isn't on huggingface.co or doesn't look
like a model-file URL. The expected shape is
``/<org>/<repo>/resolve/<rev>/<path>`` anything else
(datasets, spaces, /tree/, /blob/, ) we treat as out of scope here.
"""
if not is_hf_url(url):
return None
parts = urlparse(url).path.lstrip("/").split("/")
if len(parts) < 4 or parts[2] != "resolve":
return None
org, repo = parts[0], parts[1]
if not org or not repo:
return None
return f"{org}/{repo}"

View File

@ -0,0 +1,63 @@
"""Lazy module-level aiohttp ClientSession.
A single shared session means TLS handshakes are reused across HEAD probes
and the subsequent GETs to the same host (HuggingFace is the dominant
case), which is a noticeable speedup on cold connections.
We deliberately don't close the session at process exit — aiohttp's
warning about unclosed sessions is benign at shutdown, and adding atexit
plumbing buys nothing because the OS reclaims the sockets anyway. The
session lifetime is the lifetime of the Python process.
"""
from __future__ import annotations
import asyncio
import ssl
from typing import Optional
import aiohttp
import certifi
# Larger per-host pool than aiohttp's default (=100 total / =0 per host)
# so concurrent gated probes + a download to the same host don't queue.
_CONNECTOR_LIMIT_PER_HOST = 8
_session: Optional[aiohttp.ClientSession] = None
_lock = asyncio.Lock()
def ssl_context() -> ssl.SSLContext:
"""TLS context pinned to certifi's CA bundle.
aiohttp's default context uses the OS trust store, which isn't wired up
on some Python installs (python.org macOS, slim containers) there TLS
to huggingface.co fails with CERTIFICATE_VERIFY_FAILED.
"""
return ssl.create_default_context(cafile=certifi.where())
async def get_session() -> aiohttp.ClientSession:
"""Return the shared session, creating it on first call."""
global _session
if _session is not None and not _session.closed:
return _session
async with _lock:
if _session is None or _session.closed:
connector = aiohttp.TCPConnector(
limit_per_host=_CONNECTOR_LIMIT_PER_HOST,
ssl=ssl_context(),
)
_session = aiohttp.ClientSession(connector=connector)
return _session
def parse_content_length(value: Optional[str]) -> Optional[int]:
"""Parse a byte-count header value, or None if absent/malformed/negative."""
if not value:
return None
try:
n = int(value)
except ValueError:
return None
return n if n >= 0 else None

View File

@ -0,0 +1,93 @@
"""Path resolution for model downloads.
Model identifiers used across the download API are *relative destination
paths* of the form ``<directory>/<filename>`` (e.g. ``loras/my_lora.safetensors``).
This module turns one of those identifiers into an absolute on-disk path
under one of ComfyUI's registered model folders, while rejecting unknown
folders, path traversal, and other ill-formed inputs.
"""
import os
import re
from typing import Optional, Tuple
import folder_paths
# Constrain components so a model_id can never escape its target directory.
# - directory: a single path segment of safe chars
# - filename: a single path segment of safe chars, must end with a model extension
_SEGMENT_RE = re.compile(r"^[A-Za-z0-9._-]+$")
class InvalidModelId(ValueError):
"""Raised when a model_id is syntactically invalid or refers to an
unknown model folder."""
def parse_model_id(model_id: str) -> Tuple[str, str]:
"""Split ``<directory>/<filename>`` and validate both components.
Returns ``(directory, filename)``. Raises ``InvalidModelId`` on
malformed input. Does NOT touch the filesystem.
"""
if not isinstance(model_id, str) or "/" not in model_id:
raise InvalidModelId(f"model_id must be '<directory>/<filename>', got {model_id!r}")
directory, _, filename = model_id.partition("/")
if "/" in filename or not directory or not filename:
raise InvalidModelId(f"model_id must be exactly one '/' separator, got {model_id!r}")
if not _SEGMENT_RE.match(directory):
raise InvalidModelId(f"invalid directory segment {directory!r}")
if not _SEGMENT_RE.match(filename):
raise InvalidModelId(f"invalid filename segment {filename!r}")
if directory not in folder_paths.folder_names_and_paths:
raise InvalidModelId(f"unknown model folder {directory!r}")
return directory, filename
def resolve_existing(model_id: str) -> Optional[str]:
"""Return the absolute path of an installed model, or None if missing.
Honours ``extra_model_paths.yaml`` transparently via
``folder_paths.get_full_path``.
"""
directory, filename = parse_model_id(model_id)
return folder_paths.get_full_path(directory, filename)
def resolve_destination(model_id: str) -> Tuple[str, str]:
"""Return ``(final_path, tmp_path)`` for a download.
Downloads land at the first registered path for the model's directory
(the "primary" location). The ``.tmp`` sibling is used as the write
target and atomically renamed on success.
"""
directory, filename = parse_model_id(model_id)
roots = folder_paths.get_folder_paths(directory)
if not roots:
raise InvalidModelId(f"no on-disk path registered for folder {directory!r}")
root = roots[0]
final_path = os.path.join(root, filename)
tmp_path = final_path + ".tmp"
return final_path, tmp_path
def iter_all_tmp_paths():
"""Yield every ``*.tmp`` file under every registered model folder.
Used at startup to sweep orphans left by crashed/restarted downloads.
"""
seen_roots: set[str] = set()
for directory in folder_paths.folder_names_and_paths.keys():
for root in folder_paths.get_folder_paths(directory):
if root in seen_roots or not os.path.isdir(root):
continue
seen_roots.add(root)
try:
for entry in os.scandir(root):
if entry.is_file() and entry.name.endswith(".tmp"):
yield entry.path
except OSError:
# Folder might be unreadable / missing on certain mounts —
# not fatal, just skip it.
continue