mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 16:29:25 +08:00
feat(model_downloader): server-side model download + HuggingFace OAuth subsystem
Self-contained package under app/model_downloader/: - Allowlist + path-validated downloads (SSRF guard: HF/Civitai/localhost + model extension). - Streaming worker: writes to <final>.tmp, atomic rename on success, cooperative cancellation with epoch-based session identity, orphan .tmp sweep. - Unified availability probe with per-URL gated/size caching; is_hf_downloadable recomputed per call so login/license changes surface within one poll. - HuggingFace OAuth 2.0 PKCE flow with loopback callback server and on-disk (0600) token storage + transparent refresh. - Pydantic request/response schemas and aiohttp routes under api/. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
4676481609
commit
fdd84d04a0
46
app/model_downloader/allowlist.py
Normal file
46
app/model_downloader/allowlist.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""URL allowlist for server-side model fetches.
|
||||
|
||||
Mirrors the frontend's ``isModelDownloadable`` allowlist so the two flows
|
||||
agree on which URLs are eligible for download. Server-side allowlisting is
|
||||
the primary SSRF defense for this subsystem — workflow JSON is untrusted
|
||||
input (anyone can hand-craft one), so we never let the server fetch URLs
|
||||
outside this list.
|
||||
"""
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
# Frontend parity: ``missingModelDownload-*.js`` exports the same triple as
|
||||
# ``i = [...]`` (Civitai / HuggingFace / localhost).
|
||||
_ALLOWED_URL_PREFIXES = (
    "https://huggingface.co/",
    "https://civitai.com/",
    "http://localhost:",
    "http://127.0.0.1:",
)

# Frontend parity: same set as ``a = [...]`` in the bundle.
_ALLOWED_MODEL_EXTENSIONS = (
    ".safetensors",
    ".sft",
    ".ckpt",
    ".pth",
    ".pt",
)


def is_url_allowed(url: str) -> bool:
    """Check whether ``url`` is permitted as a server-side download source.

    Returns True only when all of the following hold:

    - the URL starts with one of the allowed prefixes,
    - its authority component parses cleanly and carries no userinfo, AND
    - the URL's path ends with a known model extension.

    The prefix and extension checks keep arbitrary HTML / API endpoints on
    allowlisted hosts (e.g. ``https://huggingface.co/api/...``) off the table.

    The userinfo check closes a bypass of the loopback prefixes: they end in
    ``:`` rather than ``/``, so ``http://localhost:@evil.example/x.pt``
    matches the prefix textually while its real host is ``evil.example``.

    Any non-string, empty, or unparseable input yields False; this function
    never raises.
    """
    if not isinstance(url, str) or not url:
        return False
    if not url.startswith(_ALLOWED_URL_PREFIXES):
        return False
    try:
        parsed = urlparse(url)
        # Reading ``port`` forces validation; it raises ValueError for a
        # non-numeric or out-of-range port.
        _ = parsed.port
    except ValueError:
        # e.g. unbalanced ``[`` / ``]`` in the authority component.
        return False
    if "@" in parsed.netloc:
        # userinfo present -> the host is whatever follows the ``@``, not the
        # allowlisted name that the prefix check matched.
        return False
    return parsed.path.endswith(_ALLOWED_MODEL_EXTENSIONS)
|
||||
332
app/model_downloader/api/routes.py
Normal file
332
app/model_downloader/api/routes.py
Normal file
@ -0,0 +1,332 @@
|
||||
"""Aiohttp routes for the server-side model download subsystem.
|
||||
|
||||
Endpoint surface (all under ``/api/``, all kebab-case):
|
||||
|
||||
- ``POST /api/models-availability-status`` — bulk status + metadata query.
|
||||
- ``POST /api/download-models`` — start a batch of downloads.
|
||||
- ``POST /api/cancel-model-download-session`` — cancel a single in-flight one.
|
||||
- ``GET /api/hf-auth-token-status`` — current HF login state.
|
||||
- ``POST /api/hf-auth-login-start`` — begin the HF OAuth flow.
|
||||
- ``POST /api/hf-auth-logout`` — drop the stored HF token.
|
||||
|
||||
The contract is intentionally narrow: only model_ids of the form
|
||||
``<directory>/<filename>`` (validated via ``app.model_downloader.paths``)
|
||||
are accepted, and only URLs on the same allowlist the frontend already
|
||||
uses (HuggingFace, Civitai, localhost) can be fetched. Both are required
|
||||
to keep the server out of the SSRF business for this feature.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from aiohttp import web
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from app.model_downloader.allowlist import is_url_allowed
|
||||
from app.model_downloader.download_server import (
|
||||
DOWNLOAD_SERVER,
|
||||
DownloadSession,
|
||||
)
|
||||
from app.model_downloader.downloader import schedule_batch
|
||||
from app.model_downloader.gated_detection import probe_url
|
||||
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
|
||||
from app.model_downloader.hf_auth.eligibility import is_hf_auth_eligible
|
||||
from app.model_downloader.hf_auth.oauth import (
|
||||
OAuthInProgressError,
|
||||
start_login_flow,
|
||||
)
|
||||
from app.model_downloader.paths import (
|
||||
InvalidModelId,
|
||||
parse_model_id,
|
||||
resolve_existing,
|
||||
)
|
||||
from app.model_downloader.api import schemas_in, schemas_out
|
||||
|
||||
ROUTES = web.RouteTableDef()


def register_routes(app: web.Application) -> None:
    """Attach every model-downloader endpoint in ``ROUTES`` to ``app``.

    ``server.py`` invokes this a single time while ``PromptServer`` starts.
    """
    app.add_routes(ROUTES)
|
||||
|
||||
|
||||
# ----- response helpers (same envelope as app/assets/api/routes.py) -----
|
||||
|
||||
|
||||
def _error(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
    """Build the JSON error envelope ``{"error": {code, message, details}}``.

    ``details`` falls back to an empty dict when omitted (or empty), so the
    key is always present in the response body.
    """
    envelope = {
        "error": {
            "code": code,
            "message": message,
            "details": details if details else {},
        }
    }
    return web.json_response(envelope, status=status)
|
||||
|
||||
|
||||
def _validation_error(code: str, ve: ValidationError) -> web.Response:
    """Turn a pydantic ``ValidationError`` into a 400 error envelope.

    The per-field errors are round-tripped through ``ve.json()`` so that the
    ``details.errors`` payload is plain JSON-serialisable data.
    """
    field_errors = json.loads(ve.json())
    return _error(400, code, "Validation failed.", {"errors": field_errors})
|
||||
|
||||
|
||||
def _ok(payload: BaseModel, status: int = 200) -> web.Response:
    """Serialise a response model to JSON and wrap it in a ``web.Response``.

    ``None`` fields are kept in the output (``exclude_none=False``).
    """
    body = payload.model_dump(mode="json", exclude_none=False)
    return web.json_response(body, status=status)
|
||||
|
||||
|
||||
async def _parse_body(request: web.Request, model: type[BaseModel]) -> Any:
|
||||
"""Parse a JSON body into a pydantic model or raise a 400 response."""
|
||||
try:
|
||||
raw = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
return _error(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
try:
|
||||
return model.model_validate(raw)
|
||||
except ValidationError as ve:
|
||||
return _validation_error("INVALID_BODY", ve)
|
||||
|
||||
|
||||
# ----- 1. availability status (unified: state + metadata per id) -----
|
||||
|
||||
|
||||
async def _probe_if_allowed(url: str):
    """Probe ``url`` only when it passes the server-side allowlist.

    Returns the ``probe_url`` result, or ``None`` for a URL the server must
    not contact. The request body is untrusted, so the allowlist is enforced
    here at the route rather than relying on ``probe_url`` to do it.
    """
    if not is_url_allowed(url):
        return None
    return await probe_url(url)


@ROUTES.post("/api/models-availability-status")
async def models_availability_status(request: web.Request) -> web.Response:
    """Return per-id ``{state, progress, file_size, is_hf_downloadable}``.

    State (``available`` / ``missing`` / ``downloading``) is cheap to
    recompute per call. ``file_size`` and ``is_gated`` are cached
    server-side per URL. ``is_hf_downloadable`` is recomputed every
    call from the current token state — that's what makes login + license
    acceptance show up in the UI within one poll cycle without any
    frontend cache plumbing.

    URLs that fail ``is_url_allowed`` are never probed; their entries carry
    ``file_size=None`` and ``is_hf_downloadable=None``. An ill-formed
    ``model_id`` is reported as ``missing`` instead of failing the whole
    batch with a 400.
    """
    parsed = await _parse_body(request, schemas_in.AvailabilityStatusRequest)
    if isinstance(parsed, web.Response):
        return parsed

    items = list(parsed.models.items())

    # Run all probes concurrently; each is internally cached per URL.
    probes = await asyncio.gather(*(_probe_if_allowed(url) for _, url in items))

    response_models: dict[str, schemas_out.ModelStatusEntry] = {}
    for (model_id, _url), probe in zip(items, probes):
        # ``probe`` is None when the URL was not allowlisted (never fetched).
        file_size = probe.file_size if probe is not None else None
        is_hf_downloadable = probe.is_hf_downloadable if probe is not None else None

        state: schemas_out.ModelState = "missing"
        progress: Optional[schemas_out.DownloadProgress] = None
        try:
            parse_model_id(model_id)
        except InvalidModelId:
            # Ill-formed identifier: report as missing without 400-ing the
            # whole batch — the workflow author probably typo'd.
            pass
        else:
            active = DOWNLOAD_SERVER.get(model_id)
            if active is not None:
                # An in-flight session takes precedence over the disk check.
                state = "downloading"
                progress = schemas_out.DownloadProgress(
                    bytes_downloaded=active.bytes_downloaded,
                    total_bytes=active.total_bytes,
                    progress=active.progress,
                )
            elif resolve_existing(model_id) is not None:
                state = "available"

        response_models[model_id] = schemas_out.ModelStatusEntry(
            state=state,
            progress=progress,
            file_size=file_size,
            is_hf_downloadable=is_hf_downloadable,
        )

    return _ok(schemas_out.AvailabilityStatusResponse(
        models=response_models,
        hf_auth=schemas_out.HfAuthStatus(
            token_available=HF_AUTH_STORE.has_token(),
            eligible=is_hf_auth_eligible(),
        ),
    ))
|
||||
|
||||
|
||||
# ----- 2. start downloads -----
|
||||
|
||||
|
||||
@ROUTES.post("/api/download-models")
async def download_models(request: web.Request) -> web.Response:
    """Validate a batch of ``{model_id: url}`` pairs and schedule them.

    All-or-nothing: every model must pass every precondition (well-formed
    id, allowlisted URL, not on disk, not already downloading, not gated
    without access) before anything is registered. The first failure is
    returned as an error envelope and no state changes. On success the
    reply is ``202`` with the list of scheduled ids.
    """
    body = await _parse_body(request, schemas_in.DownloadModelsRequest)
    if isinstance(body, web.Response):
        return body
    if not body.models:
        return _error(400, "EMPTY_REQUEST", "No models supplied.")

    wanted = list(body.models.items())

    # --- stage 1: local preconditions, checked per model in request order ---
    for model_id, url in wanted:
        try:
            parse_model_id(model_id)
        except InvalidModelId as bad_id:
            return _error(
                400, "INVALID_MODEL_ID", str(bad_id), {"model_id": model_id},
            )

        if not is_url_allowed(url):
            return _error(
                400,
                "URL_NOT_ALLOWED",
                "Server-side downloads only accept HuggingFace, Civitai, "
                "or localhost URLs ending in a known model extension.",
                {"model_id": model_id, "url": url},
            )

        if resolve_existing(model_id) is not None:
            return _error(
                409,
                "ALREADY_AVAILABLE",
                f"Model already exists on disk: {model_id}",
                {"model_id": model_id},
            )

        if DOWNLOAD_SERVER.is_downloading(model_id):
            return _error(
                409,
                "ALREADY_DOWNLOADING",
                f"A download for {model_id} is already in progress.",
                {"model_id": model_id},
            )

    # --- stage 2: reachability, the only stage that touches the network ---
    # Probes run concurrently. ``is_hf_downloadable`` is None for non-HF
    # URLs, which is treated as "no info, proceed"; only an explicit False
    # blocks the request.
    probe_results = await asyncio.gather(*[probe_url(url) for _, url in wanted])
    for (model_id, url), result in zip(wanted, probe_results):
        if result.is_hf_downloadable is not False:
            continue
        return _error(
            400,
            "MODEL_NOT_DOWNLOADABLE",
            f"Model {model_id} is gated on HuggingFace and the current "
            "server token (if any) does not grant access.",
            {"model_id": model_id, "url": url},
        )

    # --- stage 3: registration; try_register is atomic per model_id ---
    # Another request may have slipped in between stage 1 and here; a None
    # from try_register signals that, and everything registered so far by
    # this request is rolled back.
    registered: list[DownloadSession] = []
    for model_id, url in wanted:
        new_session = DOWNLOAD_SERVER.try_register(model_id, url)
        if new_session is None:
            for earlier in registered:
                DOWNLOAD_SERVER.cancel(earlier.model_id)
            return _error(
                409,
                "ALREADY_DOWNLOADING",
                f"A download for {model_id} is already in progress (race).",
                {"model_id": model_id},
            )
        registered.append(new_session)

    DOWNLOAD_SERVER.sweep_orphan_tmp_files()
    schedule_batch(registered)

    scheduled_ids = [s.model_id for s in registered]
    logging.info(
        "[model_downloader] scheduled %d downloads: %s",
        len(registered), scheduled_ids,
    )

    reply = schemas_out.DownloadModelsResponse(accepted=True, scheduled=scheduled_ids)
    return _ok(reply, status=202)
|
||||
|
||||
|
||||
# ----- 3. cancel a session -----
|
||||
|
||||
|
||||
@ROUTES.post("/api/cancel-model-download-session")
async def cancel_model_download_session(request: web.Request) -> web.Response:
    """Cancel the in-flight download for one ``model_id``.

    Replies ``{"cancelled": true}`` when a session was removed from the
    registry, or a 404 ``NOT_DOWNLOADING`` envelope when none was active.
    """
    body = await _parse_body(request, schemas_in.CancelDownloadSessionRequest)
    if isinstance(body, web.Response):
        return body

    model_id = body.model_id
    if DOWNLOAD_SERVER.cancel(model_id):
        return _ok(schemas_out.CancelDownloadSessionResponse(cancelled=True))

    return _error(
        404,
        "NOT_DOWNLOADING",
        f"No active download for {model_id}.",
        {"model_id": model_id},
    )
|
||||
|
||||
|
||||
# ----- 4. HuggingFace OAuth status / login start / logout -----
|
||||
|
||||
|
||||
@ROUTES.get("/api/hf-auth-token-status")
async def hf_auth_token_status(request: web.Request) -> web.Response:
    """Report whether the server holds an HF token, plus its username.

    Consumed by the settings UI and (out-of-band) by the frontend once a
    login completes. ``token_available`` stays true even when the cached
    access_token has expired — a stored refresh_token is enough for the
    user to count as "logged in". ``username`` is best-effort: it is
    ``None`` when no valid token can be obtained or the whoami call fails.
    """
    has_token = HF_AUTH_STORE.has_token()
    token = await HF_AUTH_STORE.get_valid_token() if has_token else None

    username: Optional[str] = None
    if token is not None:
        try:
            # huggingface_hub's whoami is synchronous and blocks on a network
            # call, so it runs in a worker thread.
            username = await asyncio.to_thread(_whoami_username, token.access_token)
        except Exception as exc:
            logging.debug("[hf_auth] whoami failed: %s", exc)

    return _ok(schemas_out.HfAuthTokenStatusResponse(
        token_available=has_token,
        username=username,
    ))
|
||||
|
||||
|
||||
def _whoami_username(token: str) -> Optional[str]:
    """Blocking helper: look up the account name that owns ``token``.

    Prefers the ``name`` field of the whoami payload and falls back to
    ``fullname``; returns ``None`` if the payload is not a dict.
    """
    from huggingface_hub import HfApi

    account = HfApi().whoami(token=token)
    if not isinstance(account, dict):
        return None
    return account.get("name") or account.get("fullname")
|
||||
|
||||
|
||||
@ROUTES.post("/api/hf-auth-login-start")
async def hf_auth_login_start(request: web.Request) -> web.Response:
    """Start a single OAuth attempt and hand back the authorize URL.

    Refused with 403 when the deployment is not eligible (the option is
    not surfaced on multi-tenant / public-IP installs), and with 409 while
    another login attempt is still pending.
    """
    if is_hf_auth_eligible():
        try:
            authorize_url = await start_login_flow()
        except OAuthInProgressError:
            return _error(
                409,
                "HF_AUTH_IN_PROGRESS",
                "Another HuggingFace login attempt is in progress. Try again "
                "after it completes or times out.",
            )
        return _ok(schemas_out.HfAuthLoginStartResponse(authorize_url=authorize_url))

    return _error(
        403,
        "HF_AUTH_NOT_ELIGIBLE",
        "This server is not eligible for interactive HuggingFace login. "
        "It must be bound to a loopback address and not running in "
        "--multi-user mode.",
    )
|
||||
|
||||
|
||||
@ROUTES.post("/api/hf-auth-logout")
async def hf_auth_logout(request: web.Request) -> web.Response:
    """Forget the stored HF token (in memory and on disk)."""
    HF_AUTH_STORE.clear()
    reply = schemas_out.HfAuthLogoutResponse(logged_out=True)
    return _ok(reply)
|
||||
41
app/model_downloader/api/schemas_in.py
Normal file
41
app/model_downloader/api/schemas_in.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""Request schemas for the model-downloader API.
|
||||
|
||||
Each endpoint accepts a small JSON body. Pydantic enforces the shape at
|
||||
the boundary; route handlers operate only on validated values past that.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AvailabilityStatusRequest(BaseModel):
    """Request body for ``POST /api/models-availability-status``.

    Sent by the frontend on each poll. ``models`` maps every ``model_id`` to
    the URL declared for it in ``properties.models[i].url`` of the workflow
    JSON, which lets the server compute the per-id metadata
    (``file_size`` + ``is_hf_downloadable``) within the same request.
    """
    # model_id -> source URL; an omitted field validates as an empty mapping.
    models: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class DownloadModelsRequest(BaseModel):
    """Request body for ``POST /api/download-models``.

    Same shape as the availability request: the URL for each ``model_id``.
    The endpoint returns immediately after validation and scheduling; it
    does not wait for the downloads themselves.
    """
    # model_id -> source URL. The field defaults to an empty mapping, which
    # the route handler rejects with ``EMPTY_REQUEST``.
    models: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class CancelDownloadSessionRequest(BaseModel):
    """Request body for ``POST /api/cancel-model-download-session``."""
    # Identifier of the download to cancel (required).
    model_id: str


# Public names exported by this module.
__all__ = [
    "AvailabilityStatusRequest",
    "DownloadModelsRequest",
    "CancelDownloadSessionRequest",
]
|
||||
81
app/model_downloader/api/schemas_out.py
Normal file
81
app/model_downloader/api/schemas_out.py
Normal file
@ -0,0 +1,81 @@
|
||||
"""Response schemas for the model-downloader API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# The three states a model can be reported in by the availability endpoint.
ModelState = Literal["available", "missing", "downloading"]


class DownloadProgress(BaseModel):
    """Embedded in a model entry when its state is ``downloading``."""
    # Bytes written so far for the in-flight session.
    bytes_downloaded: int
    # Expected size; ``None`` when it is not known.
    total_bytes: Optional[int] = None
    progress: Optional[float] = None  # fraction in [0,1]; null until total known
|
||||
|
||||
|
||||
class ModelStatusEntry(BaseModel):
    """Everything the UI needs to render one row, in one shot.

    ``state`` reflects what the server has on disk + in-flight; ``file_size``
    and ``is_hf_downloadable`` come from URL probes. Both are included in
    every poll response, so a change in HF access (login, license
    acceptance) shows up within one poll interval without any frontend
    cache invalidation.
    """
    state: ModelState
    # Only set while ``state`` is ``downloading``.
    progress: Optional[DownloadProgress] = None
    file_size: Optional[int] = None
    # HF-only: True iff the server can fetch this URL with current auth
    # state. False iff gated and lacking access. None for non-HF URLs.
    is_hf_downloadable: Optional[bool] = None
|
||||
|
||||
|
||||
class HfAuthStatus(BaseModel):
    """Snapshot of HF login state, embedded in availability response."""
    # Whether the server currently holds an HF token.
    token_available: bool
    # Whether this deployment may offer the interactive HF login flow.
    eligible: bool
|
||||
|
||||
|
||||
class AvailabilityStatusResponse(BaseModel):
    # Reply of ``POST /api/models-availability-status``.
    # One entry per requested model_id, keyed by that id.
    models: dict[str, ModelStatusEntry]
    # HF login snapshot taken when the response is built.
    hf_auth: HfAuthStatus
|
||||
|
||||
|
||||
class DownloadModelsResponse(BaseModel):
    # Reply of ``POST /api/download-models`` (sent with HTTP 202).
    accepted: bool
    # model_ids that were registered and handed to the scheduler.
    scheduled: list[str]
|
||||
|
||||
|
||||
class CancelDownloadSessionResponse(BaseModel):
    # Reply of ``POST /api/cancel-model-download-session`` on success.
    cancelled: bool
|
||||
|
||||
|
||||
class HfAuthTokenStatusResponse(BaseModel):
    # Reply of ``GET /api/hf-auth-token-status``.
    token_available: bool
    # Account name resolved for the token; ``None`` when it is unavailable.
    username: Optional[str] = None
|
||||
|
||||
|
||||
class HfAuthLoginStartResponse(BaseModel):
    # Reply of ``POST /api/hf-auth-login-start``: the URL produced by the
    # login flow for the user to open.
    authorize_url: str
|
||||
|
||||
|
||||
class HfAuthLogoutResponse(BaseModel):
    # Reply of ``POST /api/hf-auth-logout``.
    logged_out: bool


# Public names exported by this module.
__all__ = [
    "ModelState",
    "DownloadProgress",
    "ModelStatusEntry",
    "HfAuthStatus",
    "AvailabilityStatusResponse",
    "DownloadModelsResponse",
    "CancelDownloadSessionResponse",
    "HfAuthTokenStatusResponse",
    "HfAuthLoginStartResponse",
    "HfAuthLogoutResponse",
]
|
||||
179
app/model_downloader/download_server.py
Normal file
179
app/model_downloader/download_server.py
Normal file
@ -0,0 +1,179 @@
|
||||
"""Process-wide registry of in-flight model downloads.
|
||||
|
||||
A single instance, ``DOWNLOAD_SERVER``, tracks every currently-running
|
||||
server-side model fetch. Designed to be safe with multiple concurrent
|
||||
clients hitting the API: each model_id has at most one active session,
|
||||
and the API rejects requests that conflict with in-flight downloads.
|
||||
|
||||
Cancellation is cooperative — the download loop checks ``is_active`` on
|
||||
its own session between chunks and raises ``DownloadCancelled`` when the
|
||||
session has been removed from the registry. This avoids the complications
|
||||
of ``Task.cancel()`` from outside the loop while still giving deterministic
|
||||
rollback semantics (the worker is responsible for deleting its own
|
||||
``.tmp`` on the cancel path).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from app.model_downloader.paths import iter_all_tmp_paths
|
||||
|
||||
|
||||
class DownloadCancelled(Exception):
    """Signal that an in-flight download was cancelled.

    Raised by the streaming loop when its session has been removed
    from the registry (cancellation request) and the worker should roll
    back its ``.tmp`` file.
    """
|
||||
|
||||
|
||||
@dataclass
class DownloadSession:
    """State of a single in-flight download.

    ``progress`` is a fraction within ``[0.0, 1.0]`` and stays ``None`` until
    the first byte arrives and it is known whether the response carries a
    ``Content-Length``. ``total_bytes`` mirrors that header when present.
    """
    model_id: str
    url: str
    progress: Optional[float] = None
    bytes_downloaded: int = 0
    total_bytes: Optional[int] = None
    # Identity token for the cancellation check only: it lets the registry
    # tell a stale worker apart from a fresh session after "cancel + restart".
    epoch: int = 0
|
||||
|
||||
|
||||
class DownloadServer:
    """Singleton registry of active downloads.

    Every mutation is funnelled through this object so concurrent route
    handlers observe a consistent view. ``_lock`` is a plain threading lock
    because the registry is consulted both from the asyncio event-loop
    thread (route handlers) and from worker coroutines spawned to perform
    downloads.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._sessions: dict[str, DownloadSession] = {}
        self._epoch_counter = 0
        self._orphan_sweep_done = False

    def _holds(self, session: DownloadSession) -> bool:
        """Return True iff ``session`` is the one currently registered for
        its model_id (compared by epoch). The caller must hold ``_lock``."""
        registered = self._sessions.get(session.model_id)
        return registered is not None and registered.epoch == session.epoch

    # ----- lifecycle -----

    def sweep_orphan_tmp_files(self) -> None:
        """Delete ``*.tmp`` files left behind by crashed downloads, once.

        Kept off the import path so that loading the module does not block
        on filesystem I/O against potentially-slow mounts. Only the first
        call does any work; later calls return immediately.
        """
        with self._lock:
            already_swept = self._orphan_sweep_done
            self._orphan_sweep_done = True
        if already_swept:
            return
        for path in iter_all_tmp_paths():
            try:
                os.remove(path)
                logging.info("[model_downloader] removed orphan tmp file: %s", path)
            except OSError as e:
                logging.warning("[model_downloader] could not remove %s: %s", path, e)

    # ----- queries -----

    def is_downloading(self, model_id: str) -> bool:
        """Return True iff a session is registered for ``model_id``."""
        with self._lock:
            return model_id in self._sessions

    def get(self, model_id: str) -> Optional[DownloadSession]:
        """Return the session registered for ``model_id``, or ``None``."""
        with self._lock:
            return self._sessions.get(model_id)

    def snapshot(self) -> dict[str, DownloadSession]:
        """Return a shallow copy of the current sessions map."""
        with self._lock:
            return self._sessions.copy()

    def is_active(self, session: DownloadSession) -> bool:
        """True iff this exact session is still the registered one for its
        model_id. False after cancellation, after completion, or once a
        different session has replaced it."""
        with self._lock:
            return self._holds(session)

    # ----- mutations -----

    def try_register(self, model_id: str, url: str) -> Optional[DownloadSession]:
        """Atomically register a new session iff none exists for ``model_id``.

        Returns the new session on success and ``None`` if one is already in
        flight. The caller must check the result — it is the sole owner of
        the session it gets back.
        """
        with self._lock:
            if model_id in self._sessions:
                return None
            self._epoch_counter += 1
            created = DownloadSession(
                model_id=model_id,
                url=url,
                epoch=self._epoch_counter,
            )
            self._sessions[model_id] = created
            return created

    def update_progress(
        self,
        session: DownloadSession,
        bytes_downloaded: int,
        total_bytes: Optional[int],
    ) -> None:
        """Record progress for ``session``.

        Silently does nothing when the session is no longer the registered
        one (cancelled or replaced) — callers check ``is_active`` separately.
        ``progress`` is only recomputed when ``total_bytes`` is positive.
        """
        with self._lock:
            if not self._holds(session):
                return
            live = self._sessions[session.model_id]
            live.bytes_downloaded = bytes_downloaded
            live.total_bytes = total_bytes
            if total_bytes and total_bytes > 0:
                live.progress = min(1.0, bytes_downloaded / total_bytes)

    def finish(self, session: DownloadSession) -> None:
        """Drop a completed (or cancelled) session from the registry.

        Safe to call repeatedly. The entry is removed only when the epoch
        matches, so a *newer* session for the same model_id is never evicted.
        """
        with self._lock:
            if self._holds(session):
                self._sessions.pop(session.model_id)

    def cancel(self, model_id: str) -> bool:
        """Remove whatever session is registered for ``model_id``.

        Returns True if there was one. The worker notices on its next
        ``is_active`` check and rolls back its ``.tmp`` file.
        """
        with self._lock:
            return self._sessions.pop(model_id, None) is not None

    def reset_for_tests(self) -> None:
        """Clear all sessions and reset the epoch counter. Test-only."""
        with self._lock:
            self._sessions.clear()
            self._epoch_counter = 0


DOWNLOAD_SERVER = DownloadServer()
|
||||
205
app/model_downloader/downloader.py
Normal file
205
app/model_downloader/downloader.py
Normal file
@ -0,0 +1,205 @@
|
||||
"""Streaming download worker with progress reporting and cancellation.
|
||||
|
||||
Each download writes to ``<final_path>.tmp`` and atomically renames into
|
||||
place on success. Between chunks the worker checks the registry for
|
||||
cancellation (via ``DownloadServer.is_active``) and rolls back its
|
||||
``.tmp`` on cancel or on any error.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from app.model_downloader.download_server import (
|
||||
DOWNLOAD_SERVER,
|
||||
DownloadCancelled,
|
||||
DownloadSession,
|
||||
)
|
||||
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
|
||||
from app.model_downloader.hf_url import is_hf_url
|
||||
from app.model_downloader.http_client import get_session, parse_content_length
|
||||
from app.model_downloader.paths import resolve_destination
|
||||
|
||||
|
||||
CHUNK_SIZE = 64 * 1024  # 64 KiB — same scale as other ComfyUI download paths.
# No overall deadline (model files can be very large); only connecting and
# individual socket reads are bounded.
REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_connect=30, sock_read=120)


async def stream_to_disk(session: DownloadSession) -> str:
    """Run a single download to completion or cancellation.

    The body of ``session.url`` is streamed into ``<final_path>.tmp`` and
    promoted to the final path with ``os.replace`` once fully written.

    Returns:
        The final on-disk path on success.

    Raises:
        DownloadCancelled: the session was removed from the registry
            (cooperative cancel) before the file was committed.
        DownloadError: the server answered with a non-200 status.
        asyncio.CancelledError: the task running this coroutine was
            cancelled from outside.
        Exception: any other network / filesystem failure, re-raised as-is.

    On every non-success exit the ``.tmp`` file is removed. The session is
    finished (removed from the registry) exactly once, here — callers do
    not need to call ``DOWNLOAD_SERVER.finish`` themselves.
    """
    final_path, tmp_path = resolve_destination(session.model_id)
    os.makedirs(os.path.dirname(final_path), exist_ok=True)

    # Wipe any stale .tmp from a previous failed attempt before we start —
    # otherwise a partial body could masquerade as our completed download
    # when the rename finally happens.
    _remove_if_exists(tmp_path)

    bytes_seen = 0
    try:
        http = await get_session()
        headers = _auth_headers_for(session.url)
        logging.info(
            "[model_downloader] starting GET %s (auth=%s)",
            session.url, "yes" if "Authorization" in headers else "no",
        )
        async with http.get(
            session.url,
            allow_redirects=True,
            timeout=REQUEST_TIMEOUT,
            headers=headers,
        ) as resp:
            if resp.status != 200:
                # Capture a snippet of the response body so 4xx/5xx aren't
                # opaque in the logs — HF returns JSON or HTML with a
                # human-readable reason on failures.
                body_snippet = await _read_short(resp)
                logging.warning(
                    "[model_downloader] GET %s failed: status=%d final_url=%s body=%s",
                    session.url, resp.status, str(resp.url), body_snippet,
                )
                raise DownloadError(
                    f"unexpected HTTP {resp.status} fetching {session.url}: {body_snippet}",
                    status=resp.status,
                )

            total = parse_content_length(resp.headers.get("Content-Length"))
            DOWNLOAD_SERVER.update_progress(session, 0, total)

            with open(tmp_path, "wb") as f:
                async for chunk in resp.content.iter_chunked(CHUNK_SIZE):
                    # Cancellation check between chunks. Cheap and means
                    # cancellation latency is bounded by one chunk plus
                    # one ``write()`` — typically well under a second
                    # even on slow disks.
                    if not DOWNLOAD_SERVER.is_active(session):
                        raise DownloadCancelled()
                    f.write(chunk)
                    bytes_seen += len(chunk)
                    DOWNLOAD_SERVER.update_progress(session, bytes_seen, total)

        # Final cancellation check before we promote the .tmp to the real
        # filename — avoids the awkward case where cancel arrives during
        # the very last chunk and we'd otherwise commit anyway.
        if not DOWNLOAD_SERVER.is_active(session):
            raise DownloadCancelled()

        # Atomic rename. os.replace is atomic within the same filesystem,
        # which is guaranteed here because tmp lives alongside final_path.
        os.replace(tmp_path, final_path)
        logging.info(
            "[model_downloader] downloaded %s (%d bytes) from %s",
            session.model_id, bytes_seen, session.url,
        )
        return final_path

    except DownloadCancelled:
        logging.info("[model_downloader] cancelled: %s", session.model_id)
        _remove_if_exists(tmp_path)
        raise
    except asyncio.CancelledError:
        # Task-level cancellation (e.g. the enclosing task is cancelled).
        # CancelledError derives from BaseException on current Python, so
        # the generic ``except Exception`` below would not see it and the
        # partial .tmp would be left behind.
        logging.info("[model_downloader] task cancelled: %s", session.model_id)
        _remove_if_exists(tmp_path)
        raise
    except Exception as e:
        logging.warning(
            "[model_downloader] failed: %s from %s: %s: %s",
            session.model_id, session.url, type(e).__name__, e,
            exc_info=True,
        )
        _remove_if_exists(tmp_path)
        raise
    finally:
        # In all terminal states (success / cancel / error) drop the
        # session from the registry. Idempotent — only removes if we're
        # still the live epoch for this model_id.
        DOWNLOAD_SERVER.finish(session)
|
||||
|
||||
|
||||
class DownloadError(Exception):
    """Raised when a download fails at the network / protocol level.

    Attributes:
        status: Optional status code supplied by the raiser (presumably the
            HTTP status of the failed response); ``None`` when not given.
    """

    def __init__(self, message: str, status: Optional[int] = None) -> None:
        # Record the code first; the message itself is stored by ``Exception``.
        self.status = status
        super().__init__(message)
|
||||
|
||||
|
||||
async def _read_short(resp: aiohttp.ClientResponse, limit: int = 512) -> str:
|
||||
"""Read up to ``limit`` bytes of a response body for logging.
|
||||
|
||||
Used to surface the JSON/HTML reason from an HF non-2xx response in
|
||||
server logs instead of just the status code. Best-effort: any
|
||||
error here is swallowed.
|
||||
"""
|
||||
try:
|
||||
raw = await resp.content.read(limit)
|
||||
return raw.decode("utf-8", errors="replace").strip()
|
||||
except Exception:
|
||||
return "<unreadable>"
|
||||
|
||||
|
||||
def _auth_headers_for(url: str) -> dict[str, str]:
    """Build the extra auth headers (if any) for a GET of ``url``.

    Only HuggingFace URLs ever receive a header: the stored OAuth access
    token is sent as ``Authorization: Bearer <token>``. Every other host
    gets an empty dict so the token is never leaked outside HuggingFace.
    An empty dict is also returned when no token (or an empty access
    token) is stored.
    """
    headers: dict[str, str] = {}
    if is_hf_url(url):
        stored = HF_AUTH_STORE.get_token_sync()
        if stored is not None and stored.access_token:
            headers["Authorization"] = f"Bearer {stored.access_token}"
    return headers
|
||||
|
||||
|
||||
def _remove_if_exists(path: str) -> None:
|
||||
try:
|
||||
os.remove(path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except OSError as e:
|
||||
logging.warning("[model_downloader] could not remove %s: %s", path, e)
|
||||
|
||||
|
||||
async def run_batch_sequential(sessions: list[DownloadSession]) -> None:
    """Download every session in ``sessions``, strictly one at a time.

    Sessions are independent of each other: a cancelled or failed download
    never stops the remaining ones. A session that is no longer active
    when its turn comes (it was cancelled while queued) is dropped from
    the registry without touching the disk.
    """
    for pending in sessions:
        if DOWNLOAD_SERVER.is_active(pending):
            try:
                await stream_to_disk(pending)
            except DownloadCancelled:
                # stream_to_disk already logged it and removed the .tmp file.
                pass
            except Exception:
                # stream_to_disk already logged it; move on to the next one.
                pass
        else:
            # Pre-cancelled before the worker reached it: just deregister.
            DOWNLOAD_SERVER.finish(pending)
|
||||
|
||||
|
||||
# Strong references to in-flight batch tasks. The event loop only keeps weak
# references to tasks, so a fire-and-forget task whose return value the
# caller discards could otherwise be garbage-collected mid-download.
_BATCH_TASKS: set = set()


def schedule_batch(sessions: list[DownloadSession]) -> asyncio.Task:
    """Kick off ``run_batch_sequential`` on the running event loop.

    The task is fire-and-forget from the caller's point of view: the API
    handler returns immediately after scheduling and clients observe
    progress via the polling endpoints. The task is kept alive by this
    module until it completes, so callers may safely ignore the return
    value.

    Args:
        sessions: Sessions to download, in order.

    Returns:
        The scheduled task (useful for tests or explicit cancellation).

    Raises:
        RuntimeError: If there is no running event loop.
    """
    task = asyncio.create_task(run_batch_sequential(sessions))
    _BATCH_TASKS.add(task)
    # Drop our reference as soon as the task finishes (success or failure).
    task.add_done_callback(_BATCH_TASKS.discard)
    return task
|
||||
238
app/model_downloader/gated_detection.py
Normal file
238
app/model_downloader/gated_detection.py
Normal file
@ -0,0 +1,238 @@
|
||||
"""Per-URL probes for the unified availability endpoint.
|
||||
|
||||
Three cached/derived facts per URL:
|
||||
|
||||
- ``is_gated`` intrinsic to the model; cached forever once known.
|
||||
Determined by ``auth_check(repo_id, token=None)``:
|
||||
``GatedRepoError`` → True, success → False.
|
||||
|
||||
- ``is_hf_downloadable`` depends on the *current* token; recomputed every
|
||||
call. For non-gated URLs this is trivially True
|
||||
(no HF call needed). For gated URLs we run
|
||||
``auth_check`` with the stored token each call.
|
||||
|
||||
- ``file_size`` intrinsic to the file. Cached forever once
|
||||
determined (including ``None`` on transient
|
||||
failure — we don't retry). We only attempt the
|
||||
HEAD when we already know the URL is downloadable
|
||||
to us; that way a failed-because-gated probe
|
||||
never lands as a cached ``None``.
|
||||
|
||||
Caches are per-process, in-memory; small, no eviction needed for the
|
||||
workflow-scale (~tens of URLs). Concurrent calls for the same URL
|
||||
deduplicate via per-URL ``asyncio.Lock``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.errors import (
|
||||
GatedRepoError,
|
||||
HfHubHTTPError,
|
||||
RepositoryNotFoundError,
|
||||
)
|
||||
|
||||
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
|
||||
from app.model_downloader.hf_url import is_hf_url, repo_id_from_url
|
||||
from app.model_downloader.http_client import get_session, parse_content_length
|
||||
|
||||
|
||||
_HEAD_TIMEOUT = aiohttp.ClientTimeout(total=15)
|
||||
|
||||
|
||||
@dataclass
class ProbeResult:
    """What a single-URL probe learned about a download candidate."""

    # File size in bytes; ``None`` when no size was obtained.
    file_size: Optional[int]
    # Whether the file can be fetched from HuggingFace right now; ``None``
    # when not applicable (non-HF URL) or when the check gave no answer.
    is_hf_downloadable: Optional[bool]
|
||||
|
||||
|
||||
# --- caches -------------------------------------------------------------- #
|
||||
|
||||
|
||||
# url → bool. Whether this URL's HF repo gates access. Intrinsic to the
|
||||
# model — never changes for a given URL.
|
||||
_is_gated_cache: dict[str, bool] = {}
|
||||
|
||||
# url → Optional[int]. The file's size in bytes, ``None`` if a probe
|
||||
# was attempted and produced no answer. **Only populated when we knew
|
||||
# the URL was downloadable to us at probe time** — so gated-without-
|
||||
# access never lands a ``None`` here that we'd be stuck with after login.
|
||||
_file_size_cache: dict[str, Optional[int]] = {}
|
||||
|
||||
# Per-URL locks for single-flight probes — when multiple polls arrive
|
||||
# in the same tick for the same URL, exactly one of them runs the HF
|
||||
# call and the others wait on the result.
|
||||
_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
|
||||
def _lock_for(url: str) -> asyncio.Lock:
    """Return the lock dedicated to ``url``, creating it on first request."""
    try:
        return _locks[url]
    except KeyError:
        created = asyncio.Lock()
        _locks[url] = created
        return created
|
||||
|
||||
|
||||
def clear_caches_for_tests() -> None:
    """Test-only helper: empty the gated cache, the size cache and the locks."""
    for store in (_is_gated_cache, _file_size_cache, _locks):
        store.clear()
|
||||
|
||||
|
||||
# --- public entrypoint --------------------------------------------------- #
|
||||
|
||||
|
||||
async def probe_url(url: str) -> ProbeResult:
    """Return downloadability + size for one URL, using caches where safe.

    Behaviour by URL kind:

    - Non-HuggingFace URL: ``is_hf_downloadable`` is ``None`` ("not
      applicable"); the size is probed once (without auth) and cached.
    - HuggingFace URL whose repo id cannot be parsed, or whose gating
      status could not be determined this time: both fields are ``None``.
    - Non-gated HuggingFace URL: downloadable is ``True``.
    - Gated HuggingFace URL: downloadable is re-evaluated on every call
      with the current token.

    The token is obtained through ``HF_AUTH_STORE.get_valid_token()`` so an
    expired access token is refreshed first. Sending a stale Bearer token
    would make a gated repo look inaccessible, and a failed HEAD would be
    cached as a ``None`` size that is never retried.

    Args:
        url: The candidate download URL.

    Returns:
        A ``ProbeResult``; ``file_size`` is only probed when the URL is
        known to be downloadable with the current credentials.
    """
    if not is_hf_url(url):
        # Non-HF: ``is_hf_downloadable`` is "not applicable" (None).
        # Size we still cache so we don't HEAD on every poll.
        size = await _get_or_probe_size(url, token=None)
        return ProbeResult(file_size=size, is_hf_downloadable=None)

    repo_id = repo_id_from_url(url)
    if repo_id is None:
        return ProbeResult(file_size=None, is_hf_downloadable=None)

    # Determine intrinsic gating once.
    gated = await _resolve_is_gated(url, repo_id)
    if gated is None:
        return ProbeResult(file_size=None, is_hf_downloadable=None)

    # Compute current-token downloadability per call. ``get_valid_token``
    # returns None when logged out or when the refresh failed, in which
    # case we probe anonymously.
    tok = await HF_AUTH_STORE.get_valid_token()
    token_str: Optional[str] = tok.access_token if tok else None
    if not gated:
        is_hf_downloadable: Optional[bool] = True
    else:
        is_hf_downloadable = await _auth_check_with_token(repo_id, token_str)

    if is_hf_downloadable is True:
        size = await _get_or_probe_size(url, token=token_str)
    else:
        # Skip the HEAD entirely — would 401 and we'd be stuck with
        # cached None that survives a later login.
        size = None

    return ProbeResult(file_size=size, is_hf_downloadable=is_hf_downloadable)
|
||||
|
||||
|
||||
# --- gated/auth probes --------------------------------------------------- #
|
||||
|
||||
|
||||
async def _resolve_is_gated(url: str, repo_id: str) -> Optional[bool]:
    """Work out, at most once per URL, whether ``repo_id`` gates access.

    Returns the cached answer when there is one. Otherwise runs a
    token-less ``auth_check`` under the per-URL lock: success means not
    gated (``False``); ``GatedRepoError`` or ``RepositoryNotFoundError``
    means gated (``True``). Any other failure yields ``None`` and is
    deliberately *not* cached so the next call retries.
    """
    known = _is_gated_cache.get(url)
    if known is not None:
        return known

    async with _lock_for(url):
        # Another waiter may have filled the cache while we queued.
        known = _is_gated_cache.get(url)
        if known is not None:
            return known

        try:
            await asyncio.to_thread(_auth_check_sync, repo_id, None)
        except (GatedRepoError, RepositoryNotFoundError):
            # "Not found" without auth is treated like gated: we can't
            # serve it anonymously, and an authenticated check might still
            # succeed if it's a private repo the user can see.
            verdict = True
        except Exception as e:
            logging.debug(
                "[hf_auth] is_gated probe failed for %s (will retry): %s",
                repo_id, e,
            )
            return None  # don't cache; retry next call
        else:
            verdict = False

        _is_gated_cache[url] = verdict
        return verdict
|
||||
|
||||
|
||||
async def _auth_check_with_token(
    repo_id: str, token: Optional[str]
) -> Optional[bool]:
    """Run ``auth_check`` for ``repo_id`` with ``token`` and classify the result.

    Returns ``True`` when the check passes, ``False`` when access is denied
    (gated, not found, or HTTP 401/403), and ``None`` for any other failure
    where no verdict can be given.
    """
    try:
        await asyncio.to_thread(_auth_check_sync, repo_id, token)
    except (GatedRepoError, RepositoryNotFoundError):
        return False
    except HfHubHTTPError as e:
        # 401/403 covers org-SSO-required, revoked tokens, and similar —
        # all of which mean "can't fetch right now" from the user's POV.
        response = getattr(e, "response", None)
        if getattr(response, "status_code", None) in (401, 403):
            return False
        logging.debug(
            "[hf_auth] auth_check transient failure for %s: %s", repo_id, e,
        )
        return None
    except Exception as e:
        logging.warning("[hf_auth] unexpected auth_check error for %s: %s", repo_id, e)
        return None
    return True
|
||||
|
||||
|
||||
def _auth_check_sync(repo_id: str, token: Optional[str]) -> None:
    """Thin sync wrapper around ``HfApi.auth_check`` for ``asyncio.to_thread``.

    Args:
        repo_id: ``<org>/<repo>`` to check.
        token: Access token to authenticate with, or ``None`` / empty for
            a strictly anonymous check.

    Raises:
        Whatever ``HfApi.auth_check`` raises (``GatedRepoError``,
        ``RepositoryNotFoundError``, ``HfHubHTTPError``, ...).
    """
    # huggingface_hub treats ``token=None`` as "use the locally saved token"
    # (CLI login / HF_TOKEN env), which would make the "anonymous" gating
    # probe silently authenticated and cache a gated repo as public. This
    # subsystem must only ever use its own OAuth token, so map "no token"
    # to ``False``, which disables authentication.
    HfApi().auth_check(repo_id, token=token if token else False)
|
||||
|
||||
|
||||
# --- size probe ---------------------------------------------------------- #
|
||||
|
||||
|
||||
async def _get_or_probe_size(url: str, token: Optional[str]) -> Optional[int]:
    """Return the size for ``url``, probing at most once per URL.

    The first result — including ``None`` — is stored and returned on all
    later calls. The probe runs under the per-URL lock so concurrent
    callers share a single HEAD request.
    """
    try:
        return _file_size_cache[url]
    except KeyError:
        pass

    async with _lock_for(url):
        # Re-check: a concurrent caller may have probed while we waited.
        if url not in _file_size_cache:
            _file_size_cache[url] = await _probe_size_once(url, token=token)
        return _file_size_cache[url]
|
||||
|
||||
|
||||
async def _probe_size_once(url: str, token: Optional[str]) -> Optional[int]:
    """HEAD the URL and return the file size in bytes, or None on failure.

    HuggingFace serves LFS-tracked files via a 302 to a signed CDN URL.
    The real file size lives in the non-standard ``X-Linked-Size`` header
    on that 302 response (``Content-Length`` is the redirect-body length).
    Disabling redirect-follow lets us read either header on the same
    response:

    - LFS files:           302 + ``X-Linked-Size``
    - Small/non-LFS files: 200 + ``Content-Length``

    Args:
        url: URL to probe.
        token: Optional bearer token sent as ``Authorization`` header.

    Returns:
        The size in bytes, or ``None`` when the response carries no usable
        size or the request fails (client error, timeout, OS error).
    """
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    try:
        session = await get_session()
        async with session.head(
            url, allow_redirects=False, timeout=_HEAD_TIMEOUT, headers=headers,
        ) as resp:
            # Prefer the LFS header: it is the real size even on a 302.
            linked = parse_content_length(resp.headers.get("X-Linked-Size"))
            if linked is not None:
                return linked
            if resp.status == 200:
                return parse_content_length(resp.headers.get("Content-Length"))
            return None
    # ``asyncio.TimeoutError`` only became an alias of the builtin
    # ``TimeoutError`` in Python 3.11; list it explicitly so a HEAD timeout
    # is also swallowed on older interpreters.
    except (aiohttp.ClientError, asyncio.TimeoutError, TimeoutError, OSError):
        return None
|
||||
|
||||
|
||||
# Backward-compat shim so consumers that still import the old name keep
|
||||
# building during the refactor; can be removed once routes are updated.
|
||||
MetadataProbeResult = ProbeResult
|
||||
106
app/model_downloader/hf_auth/auth_store.py
Normal file
106
app/model_downloader/hf_auth/auth_store.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""In-memory token cache with lazy disk persistence + refresh.
|
||||
|
||||
Public surface is the ``HF_AUTH_STORE`` singleton. Callers ask
|
||||
``get_valid_token()``; the store transparently loads the token from disk
|
||||
on first use, refreshes via the OAuth refresh_token if the cached
|
||||
access_token is expired, and returns ``None`` if neither path works.
|
||||
|
||||
The refresh path imports ``oauth.refresh_access_token`` lazily to
|
||||
avoid an import cycle (oauth needs the store to save tokens it
|
||||
acquires).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from app.model_downloader.hf_auth.token_store import (
|
||||
Token,
|
||||
delete_token,
|
||||
load_token,
|
||||
save_token,
|
||||
)
|
||||
|
||||
|
||||
class HfAuthStore:
    """Process-wide holder of the HuggingFace OAuth token.

    The token is loaded from disk lazily on first access, kept in memory,
    and persisted on every change. ``get_valid_token`` additionally
    refreshes an expired token through the OAuth refresh grant.
    """

    def __init__(self) -> None:
        # Guards ``_token`` / ``_loaded_from_disk`` against concurrent
        # access from worker threads.
        self._lock = threading.Lock()
        self._token: Optional[Token] = None
        self._loaded_from_disk = False
        # Serialises refresh attempts. Created lazily inside a coroutine
        # because this object is instantiated at import time.
        self._refresh_lock: Optional[asyncio.Lock] = None

    def _ensure_loaded(self) -> None:
        """Read the disk token into memory on first access."""
        if self._loaded_from_disk:
            return
        with self._lock:
            # Double-check under the lock: another thread may have loaded.
            if self._loaded_from_disk:
                return
            self._token = load_token()
            self._loaded_from_disk = True

    def has_token(self) -> bool:
        """Cheap check: is there any token in memory?

        Does not attempt refresh; an expired-but-refreshable token still
        counts as "logged in" from the user's perspective.
        """
        self._ensure_loaded()
        return self._token is not None

    def set_token(self, token: Token) -> None:
        """Replace the in-memory token and persist to disk."""
        with self._lock:
            self._token = token
            self._loaded_from_disk = True
            save_token(token)

    def clear(self) -> None:
        """Forget the token in memory and on disk (logout)."""
        with self._lock:
            self._token = None
            self._loaded_from_disk = True
            delete_token()

    def get_token_sync(self) -> Optional[Token]:
        """Return the cached token without refreshing.

        Sync callers (e.g. constructing an Authorization header in a
        non-async path) use this. They accept an expired token over
        ``None``; HF will simply return 401 and the caller can decide
        what to do.
        """
        self._ensure_loaded()
        return self._token

    async def get_valid_token(self) -> Optional[Token]:
        """Return a fresh token, refreshing via OAuth if necessary.

        Returns ``None`` if there's no cached token at all, or if the
        cached token is expired and refresh failed. Callers should
        treat that as "user is not logged in".

        Refreshes are single-flight: concurrent callers with an expired
        token share one refresh request instead of each sending their own
        (the losers would fail if the server rotates refresh tokens). A
        refresh result is discarded if the stored token changed meanwhile
        (logout or a new login), so it can never resurrect a cleared token.
        """
        self._ensure_loaded()
        tok = self._token
        if tok is None:
            return None
        if tok.is_valid():
            return tok
        if not tok.refresh_token:
            return None

        # No await between the check and the assignment, so this lazy
        # creation is safe on a single event loop.
        if self._refresh_lock is None:
            self._refresh_lock = asyncio.Lock()

        async with self._refresh_lock:
            # Re-read: a previous holder of the lock may already have
            # refreshed, or the user may have logged out while we waited.
            current = self._token
            if current is None:
                return None
            if current.is_valid():
                return current
            if not current.refresh_token:
                return None

            # Lazy import to avoid the oauth ↔ store import cycle.
            from app.model_downloader.hf_auth.oauth import refresh_access_token

            try:
                refreshed = await refresh_access_token(current.refresh_token)
            except Exception as e:
                logging.warning("[hf_auth] token refresh failed: %s", e)
                return None

            if self._token is not current:
                # Token was replaced or cleared during the network call;
                # honour the newer state instead of overwriting it.
                latest = self._token
                if latest is not None and latest.is_valid():
                    return latest
                return None

            self.set_token(refreshed)
            return refreshed
|
||||
|
||||
|
||||
HF_AUTH_STORE = HfAuthStore()
|
||||
55
app/model_downloader/hf_auth/eligibility.py
Normal file
55
app/model_downloader/hf_auth/eligibility.py
Normal file
@ -0,0 +1,55 @@
|
||||
"""Whether this deployment is allowed to do interactive HF OAuth.
|
||||
|
||||
We only let the server hold a HuggingFace token under a strict trust
|
||||
assumption: this is a *single tenant local* install. Concretely:
|
||||
|
||||
- The server is bound to a loopback address. SSH tunneling /
|
||||
reverse-proxies can defeat this, but it's the strongest signal
|
||||
we have without an authentication system.
|
||||
- ``--multi-user`` is off. A shared token used implicitly by multiple
|
||||
declared users would be a footgun — one user's gated downloads
|
||||
would silently authenticate as another.
|
||||
|
||||
Anything else and the frontend hides the HF login UI entirely; gated
|
||||
models continue to show the "acquire it manually" message.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
|
||||
def _is_loopback(host: str | None) -> bool:
|
||||
"""Duplicates ``server.is_loopback`` (small, no shared module yet).
|
||||
|
||||
Resolves a host or IP literal to whether it lives on the loopback
|
||||
interface (127.0.0.0/8 for IPv4, ::1 for IPv6). Returns False for
|
||||
``0.0.0.0`` / ``::`` because those are bind-all wildcards, not
|
||||
loopback.
|
||||
"""
|
||||
if host is None:
|
||||
return False
|
||||
try:
|
||||
return ipaddress.ip_address(host).is_loopback
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
loopback = False
|
||||
for family in (socket.AF_INET, socket.AF_INET6):
|
||||
try:
|
||||
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
|
||||
for _family, _, _, _, sockaddr in r:
|
||||
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
|
||||
return loopback
|
||||
loopback = True
|
||||
except socket.gaierror:
|
||||
pass
|
||||
return loopback
|
||||
|
||||
|
||||
def is_hf_auth_eligible() -> bool:
    """Report whether the HF OAuth flow may be offered on this deployment.

    Eligible means: the configured listen address is loopback-only and
    multi-user mode is off.
    """
    bound_to_loopback = _is_loopback(args.listen)
    if not bound_to_loopback:
        return False
    return not args.multi_user
|
||||
277
app/model_downloader/hf_auth/oauth.py
Normal file
277
app/model_downloader/hf_auth/oauth.py
Normal file
@ -0,0 +1,277 @@
|
||||
"""OAuth 2.0 PKCE flow against HuggingFace's authorization server.
|
||||
|
||||
Wired so that ``POST /api/hf-auth-login-start`` can:
|
||||
1. Generate state + PKCE verifier/challenge in this process.
|
||||
2. Spin up a short-lived loopback HTTP server at port 41954 to
|
||||
receive the redirect callback from HF.
|
||||
3. Return the ``authorize_url`` for the frontend to open in a new tab.
|
||||
|
||||
After the user grants consent on huggingface.co, HF redirects to the
|
||||
local callback URL with ``code`` and ``state``. The callback server
|
||||
validates ``state`` (CSRF), exchanges the code for tokens via PKCE,
|
||||
hands the resulting Token to ``HF_AUTH_STORE.set_token``, and shuts
|
||||
itself down.
|
||||
|
||||
Before this can be exercised end-to-end a maintainer must register a
|
||||
HuggingFace OAuth app and substitute the ``HF_CLIENT_ID`` placeholder
|
||||
below. See the comment above the constant for the exact steps.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
||||
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
|
||||
from app.model_downloader.hf_auth.token_store import Token
|
||||
from app.model_downloader.http_client import ssl_context
|
||||
|
||||
|
||||
# --- HF OAuth app registration -------------------------------------------- #
|
||||
# NOTE: The OAuth client_id below is a placeholder. Before this feature can be
|
||||
# exercised end-to-end, a maintainer must register a HuggingFace OAuth app
|
||||
# under a Comfy-Org-controlled HF account and substitute its client_id here.
|
||||
# Detailed walkthrough is in docs/server-side-model-downloads-handover.html
|
||||
# ("HuggingFace OAuth app setup" section). Short version:
|
||||
# 1. huggingface.co → Settings → Connected Apps → "Create app"
|
||||
# 2. Default Scopes: check ``openid`` + ``profile`` (User Info) and
|
||||
# ``gated-repos`` (Repository Access). Leave everything else off.
|
||||
# 3. Redirect URLs: exactly ``http://127.0.0.1:41954/api/auth/huggingface/callback``
|
||||
# — must match ``REDIRECT_URI`` below; change both in lockstep if you
|
||||
# change ``CALLBACK_PORT``.
|
||||
# 4. Save → copy the resulting Client ID into ``HF_CLIENT_ID`` below.
|
||||
# The client_id is not a secret (it travels through the user's browser in
|
||||
# plaintext); HF's "Public app" type means there's no client secret to
|
||||
# manage — PKCE replaces it.
|
||||
HF_CLIENT_ID = "REPLACE_ME_WITH_COMFY_ORG_HF_OAUTH_CLIENT_ID"
|
||||
|
||||
CALLBACK_HOST = "127.0.0.1"
|
||||
CALLBACK_PORT = 41954
|
||||
CALLBACK_PATH = "/api/auth/huggingface/callback"
|
||||
REDIRECT_URI = f"http://{CALLBACK_HOST}:{CALLBACK_PORT}{CALLBACK_PATH}"
|
||||
|
||||
AUTHORIZE_URL = "https://huggingface.co/oauth/authorize"
|
||||
TOKEN_URL = "https://huggingface.co/oauth/token"
|
||||
# Minimal scope set for the feature:
|
||||
# - openid : required by HF when the app uses OIDC at all
|
||||
# - profile : lets ``HfApi.whoami(token=...)`` return a username for the
|
||||
# settings UI; cosmetic but expected
|
||||
# - gated-repos : grants the token enough to call ``auth_check`` and
|
||||
# download files from public gated repos the user has
|
||||
# accepted the license for. The wider ``read-repos`` scope
|
||||
# would also work (it includes ``gated-repos``) but it
|
||||
# additionally grants private-repo read access, which we
|
||||
# don't need and which makes the consent screen scarier
|
||||
# for the user.
|
||||
SCOPE = "openid profile gated-repos"
|
||||
|
||||
# Maximum time the callback server stays up waiting for the user to
|
||||
# complete consent on huggingface.co. Past this, the port closes and
|
||||
# the user has to click "Log in" again.
|
||||
CALLBACK_TIMEOUT_SECS = 300
|
||||
|
||||
|
||||
# Process-wide lock so two simultaneous /api/hf-auth-login-start
|
||||
# requests don't fight over port CALLBACK_PORT.
|
||||
_OAUTH_LOCK = threading.Lock()
|
||||
|
||||
|
||||
class OAuthInProgressError(Exception):
    """Raised when a login is requested while another attempt holds the OAuth lock."""
|
||||
|
||||
|
||||
class OAuthCallbackError(Exception):
    """Raised when the OAuth redirect callback reports a failure (e.g. HF denied consent)."""
|
||||
|
||||
|
||||
# --- PKCE primitives ------------------------------------------------------ #
|
||||
|
||||
|
||||
def _make_pkce() -> tuple[str, str, str]:
|
||||
"""Return ``(verifier, challenge, state)``.
|
||||
|
||||
Verifier never leaves this process. Challenge and state travel
|
||||
through the user's browser. State is checked on the callback to
|
||||
prevent a malicious cross-origin redirect from injecting a token.
|
||||
"""
|
||||
verifier = secrets.token_urlsafe(64)
|
||||
challenge = (
|
||||
base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("ascii")).digest())
|
||||
.rstrip(b"=")
|
||||
.decode("ascii")
|
||||
)
|
||||
state = secrets.token_urlsafe(32)
|
||||
return verifier, challenge, state
|
||||
|
||||
|
||||
def _build_authorize_url(challenge: str, state: str) -> str:
    """Compose the HuggingFace authorize URL for one login attempt.

    The query carries the client id, redirect URI, requested scope, the
    CSRF ``state`` and the PKCE ``challenge`` (method ``S256``).
    """
    from urllib.parse import urlencode

    query = urlencode(
        [
            ("client_id", HF_CLIENT_ID),
            ("redirect_uri", REDIRECT_URI),
            ("response_type", "code"),
            ("scope", SCOPE),
            ("state", state),
            ("code_challenge", challenge),
            ("code_challenge_method", "S256"),
        ]
    )
    return AUTHORIZE_URL + "?" + query
|
||||
|
||||
|
||||
# --- Token exchange ------------------------------------------------------- #
|
||||
|
||||
|
||||
async def _exchange_code(code: str, verifier: str) -> Token:
    """Redeem an authorization ``code`` (plus PKCE ``verifier``) for tokens.

    POSTs the ``authorization_code`` grant to the token endpoint with a
    30 second total timeout and raises on a non-success HTTP status. When
    the response omits ``expires_in`` a lifetime of 3600 seconds is
    assumed; a missing ``scope`` falls back to the requested ``SCOPE``.
    """
    form = dict(
        grant_type="authorization_code",
        code=code,
        redirect_uri=REDIRECT_URI,
        client_id=HF_CLIENT_ID,
        code_verifier=verifier,
    )
    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) as http:
        async with http.post(TOKEN_URL, data=form, ssl=ssl_context()) as resp:
            resp.raise_for_status()
            payload = await resp.json()

    access = payload["access_token"]
    refresh = payload.get("refresh_token")
    # Convert the relative lifetime into an absolute epoch timestamp.
    expiry = time.time() + float(payload.get("expires_in", 3600))
    granted = payload.get("scope", SCOPE)
    return Token(
        access_token=access,
        refresh_token=refresh,
        expires_at=expiry,
        scope=granted,
    )
|
||||
|
||||
|
||||
async def refresh_access_token(refresh_token: str) -> Token:
    """Trade a refresh_token for a new access (+ possibly refresh) token.

    Args:
        refresh_token: The refresh token currently held.

    Returns:
        A new ``Token``. If the server does not send a replacement refresh
        token, the supplied one is carried over so the next refresh still
        works.

    Raises:
        aiohttp.ClientResponseError: On a non-success HTTP status.
        KeyError: If the response lacks ``access_token``.
    """
    data = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": HF_CLIENT_ID,
    }
    timeout = aiohttp.ClientTimeout(total=30)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(TOKEN_URL, data=data, ssl=ssl_context()) as resp:
            resp.raise_for_status()
            body = await resp.json()
    return Token(
        access_token=body["access_token"],
        # If HF doesn't rotate refresh tokens, keep using the existing one.
        # ``or`` (rather than a ``.get`` default) also covers an explicit
        # ``"refresh_token": null`` / empty string in the response, which
        # would otherwise discard the only refresh token we have.
        refresh_token=body.get("refresh_token") or refresh_token,
        expires_at=time.time() + float(body.get("expires_in", 3600)),
        scope=body.get("scope", SCOPE),
    )
|
||||
|
||||
|
||||
# --- Callback server ------------------------------------------------------ #
|
||||
|
||||
|
||||
# Strong references to running callback-server tasks. The event loop keeps
# only weak references to tasks, so without this the background task could
# be garbage-collected while it waits for the user to finish consent.
_LOGIN_TASKS: set = set()


async def start_login_flow() -> str:
    """Begin one OAuth attempt: spawn the callback server, return the URL.

    Returns the URL the frontend should open in a new tab. The callback
    server runs in the background until the user completes consent or
    until ``CALLBACK_TIMEOUT_SECS`` elapses; either way the lock + port
    are released afterward by the background task.

    Raises:
        OAuthInProgressError: If another attempt is already running.
    """
    if not _OAUTH_LOCK.acquire(blocking=False):
        raise OAuthInProgressError()

    try:
        verifier, challenge, state = _make_pkce()
        authorize_url = _build_authorize_url(challenge, state)
        # Fire the callback server on the running loop and return.
        task = asyncio.create_task(_run_callback_server(verifier, state))
    except BaseException:
        # Nothing was scheduled, so nobody else will ever release the lock;
        # do it here or every later login attempt would be rejected.
        _OAUTH_LOCK.release()
        raise

    _LOGIN_TASKS.add(task)
    task.add_done_callback(_LOGIN_TASKS.discard)
    return authorize_url
|
||||
|
||||
|
||||
async def _run_callback_server(verifier: str, expected_state: str) -> None:
    """Listen for HF's redirect once, capture the token, then shut down.

    Serves ``CALLBACK_PATH`` on ``CALLBACK_HOST:CALLBACK_PORT`` until a
    callback with the expected ``state`` delivers a token or an error, or
    until ``CALLBACK_TIMEOUT_SECS`` pass. On success the token is handed
    to ``HF_AUTH_STORE.set_token``.

    However the function exits (success, timeout, bind failure, setup
    failure or cancellation), the runner is cleaned up and ``_OAUTH_LOCK``
    is released so a later login attempt can proceed.

    Args:
        verifier: PKCE code verifier for the token exchange.
        expected_state: CSRF state the callback must echo back.
    """
    received: asyncio.Future[Token] = asyncio.get_running_loop().create_future()

    async def handler(request: web.Request) -> web.Response:
        try:
            # Requests with a wrong state are rejected without ending the
            # wait: they did not originate from our authorize URL.
            if request.query.get("state") != expected_state:
                return web.Response(status=400, text="state mismatch")
            err = request.query.get("error")
            if err:
                # Guard against a second callback hitting a settled future.
                if not received.done():
                    received.set_exception(OAuthCallbackError(f"HF returned: {err}"))
                return web.Response(status=400, text=f"OAuth error: {err}")
            code = request.query.get("code")
            if not code:
                return web.Response(status=400, text="missing code")
            tok = await _exchange_code(code, verifier)
            if not received.done():
                received.set_result(tok)
            return web.Response(
                content_type="text/html",
                text=(
                    "<html><body style='font-family:sans-serif;padding:40px'>"
                    "<h2>HuggingFace login successful</h2>"
                    "<p>You can close this tab and return to ComfyUI.</p>"
                    "</body></html>"
                ),
            )
        except Exception as exc:
            if not received.done():
                received.set_exception(exc)
            return web.Response(status=500, text=str(exc))

    app = web.Application()
    app.router.add_get(CALLBACK_PATH, handler)
    runner = web.AppRunner(app)

    try:
        await runner.setup()
        site = web.TCPSite(runner, CALLBACK_HOST, CALLBACK_PORT, reuse_address=True)
        try:
            await site.start()
        except OSError as e:
            # Port already in use (or some other socket-bind failure). The
            # outer ``finally`` releases the lock so a future attempt has a
            # chance to succeed.
            logging.warning("[hf_auth] could not bind callback port: %s", e)
            return

        try:
            token = await asyncio.wait_for(received, timeout=CALLBACK_TIMEOUT_SECS)
        except asyncio.TimeoutError:
            logging.info("[hf_auth] OAuth login timed out after %ds", CALLBACK_TIMEOUT_SECS)
            return
        except OAuthCallbackError as e:
            logging.warning("[hf_auth] OAuth callback error: %s", e)
            return
        except Exception as e:
            logging.warning("[hf_auth] unexpected OAuth failure: %s", e)
            return
        else:
            HF_AUTH_STORE.set_token(token)
            logging.info("[hf_auth] OAuth login complete")
    finally:
        try:
            await runner.cleanup()
        finally:
            # Release even if cleanup itself fails; otherwise the lock
            # would stay held and block every later login attempt.
            if _OAUTH_LOCK.locked():
                _OAUTH_LOCK.release()
|
||||
|
||||
|
||||
def is_login_in_progress() -> bool:
    """Report whether an OAuth attempt currently holds ``_OAUTH_LOCK``."""
    attempt_running = _OAUTH_LOCK.locked()
    return attempt_running
|
||||
|
||||
|
||||
# Public surface of this module; private helpers (PKCE, URL builder, token exchange) are intentionally left out.
|
||||
__all__ = [
|
||||
"start_login_flow",
|
||||
"refresh_access_token",
|
||||
"is_login_in_progress",
|
||||
"OAuthInProgressError",
|
||||
"CALLBACK_TIMEOUT_SECS",
|
||||
]
|
||||
89
app/model_downloader/hf_auth/token_store.py
Normal file
89
app/model_downloader/hf_auth/token_store.py
Normal file
@ -0,0 +1,89 @@
|
||||
"""On-disk persistence for the HuggingFace OAuth token.
|
||||
|
||||
The token shape mirrors what HF returns on the token exchange: an
|
||||
``access_token``, an optional ``refresh_token``, the absolute epoch at
|
||||
which the access token expires, and the granted scope. We persist
|
||||
this so logging in once survives ComfyUI restarts; the file is mode
|
||||
``0600`` so only the OS user can read it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import stat
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Optional
|
||||
|
||||
import folder_paths
|
||||
|
||||
|
||||
# Treat a token as expired this many seconds before its server-reported
|
||||
# ``expires_at`` so we don't try to use a token mid-request only for it
|
||||
# to flip stale between auth_check and the actual GET.
|
||||
EXPIRY_BUFFER_SECS = 60
|
||||
|
||||
TOKEN_FILENAME = "hf_auth_token.json"
|
||||
|
||||
|
||||
@dataclass
class Token:
    """An OAuth token together with the bookkeeping needed to use it."""

    access_token: str
    refresh_token: Optional[str]
    expires_at: float  # absolute epoch seconds
    scope: str = ""

    def is_valid(self) -> bool:
        """Tell whether the access token is usable right now.

        A token is usable when it is non-empty and has more than
        ``EXPIRY_BUFFER_SECS`` seconds left before ``expires_at``.
        """
        if not self.access_token:
            return False
        seconds_left = self.expires_at - time.time()
        return seconds_left > EXPIRY_BUFFER_SECS
|
||||
|
||||
|
||||
def _token_path() -> str:
    """Return the full path of the token file inside the user directory."""
    return os.path.join(folder_paths.get_user_directory(), TOKEN_FILENAME)
|
||||
|
||||
|
||||
def load_token() -> Optional[Token]:
    """Read the persisted token, returning ``None`` if absent or corrupt.

    A missing file is the normal "not logged in" case and is silent. Any
    other problem — unreadable file, invalid JSON, undecodable bytes, or
    JSON whose shape does not match ``Token`` — is logged as a warning
    and also yields ``None``.
    """
    path = _token_path()
    try:
        # Open directly instead of checking ``os.path.exists`` first: the
        # file can disappear between the check and the open.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return Token(**data)
    except FileNotFoundError:
        return None
    # ``ValueError`` covers ``json.JSONDecodeError`` and also
    # ``UnicodeDecodeError`` from a file that is not valid UTF-8;
    # ``TypeError`` covers non-dict JSON and missing/unexpected keys.
    except (OSError, ValueError, TypeError) as e:
        logging.warning("[hf_auth] could not load token at %s: %s", path, e)
        return None
|
||||
|
||||
|
||||
def save_token(token: Token) -> None:
    """Atomically write the token with 0600 permissions.

    The JSON is written to a ``.tmp`` sibling that is created owner-only
    from the start, then moved over the final path with ``os.replace``.
    The secret is therefore never on disk with umask-default permissions,
    and readers see either the old file or the complete new one.

    Args:
        token: The token to persist; serialised with ``dataclasses.asdict``.

    Raises:
        OSError: If the directory or file cannot be created, written or
            replaced. The temporary file is removed before re-raising.
    """
    path = _token_path()
    os.makedirs(os.path.dirname(path), exist_ok=True)
    tmp = path + ".tmp"
    owner_rw = stat.S_IRUSR | stat.S_IWUSR
    try:
        # Passing the mode to os.open means a freshly created file is 0600
        # before a single byte of the secret is written.
        fd = os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, owner_rw)
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            try:
                # The creation mode is ignored when a stale tmp file already
                # exists, so tighten it explicitly before writing.
                os.chmod(tmp, owner_rw)
            except OSError as e:
                # On Windows / weird filesystems chmod may be a no-op; not fatal.
                logging.debug("[hf_auth] chmod 0600 on %s failed: %s", tmp, e)
            json.dump(asdict(token), f)
        os.replace(tmp, path)
    except BaseException:
        # Do not leave a partial secret behind if anything went wrong.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
|
||||
|
||||
|
||||
def delete_token() -> None:
    """Delete the persisted token file.

    A missing file is silently ignored; any other ``OSError`` is logged as
    a warning and not re-raised.
    """
    target = _token_path()
    try:
        os.remove(target)
    except OSError as err:
        if isinstance(err, FileNotFoundError):
            # Already gone: nothing to do.
            return
        logging.warning("[hf_auth] could not remove token at %s: %s", target, err)
|
||||
41
app/model_downloader/hf_url.py
Normal file
41
app/model_downloader/hf_url.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""Parsers for the ``huggingface.co`` URL shape we accept in workflows.
|
||||
|
||||
The download API accepts URLs of the form
|
||||
``https://huggingface.co/<org>/<repo>/resolve/<rev>/<path/to/file>``.
|
||||
We need to recover ``<org>/<repo>`` (the *repo_id*) from such URLs for
|
||||
``huggingface_hub`` API calls (notably ``HfApi.auth_check``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
_HF_HOST = "huggingface.co"
|
||||
|
||||
|
||||
def is_hf_url(url: str) -> bool:
    """Return True when the URL's hostname equals ``_HF_HOST``.

    Only the host is compared; scheme and path are not inspected. A URL
    that ``urlparse`` rejects with ``ValueError`` yields ``False``.
    """
    try:
        host = urlparse(url).hostname
    except ValueError:
        return False
    return host == _HF_HOST
|
||||
|
||||
|
||||
def repo_id_from_url(url: str) -> Optional[str]:
    """Recover the ``<org>/<repo>`` repo id from an HF file URL.

    The URL must pass ``is_hf_url`` and its path must have at least four
    segments with ``resolve`` as the third (``/<org>/<repo>/resolve/...``).
    Anything else, or an empty org or repo segment, yields ``None``.
    """
    if not is_hf_url(url):
        return None
    segments = urlparse(url).path.lstrip("/").split("/")
    if len(segments) >= 4 and segments[2] == "resolve":
        owner, name = segments[0], segments[1]
        if owner and name:
            return f"{owner}/{name}"
    return None
|
||||
63
app/model_downloader/http_client.py
Normal file
63
app/model_downloader/http_client.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""Lazy module-level aiohttp ClientSession.
|
||||
|
||||
A single shared session means TLS handshakes are reused across HEAD probes
|
||||
and the subsequent GETs to the same host (HuggingFace is the dominant
|
||||
case), which is a noticeable speedup on cold connections.
|
||||
|
||||
We deliberately don't close the session at process exit — aiohttp's
|
||||
warning about unclosed sessions is benign at shutdown, and adding atexit
|
||||
plumbing buys nothing because the OS reclaims the sockets anyway. The
|
||||
session lifetime is the lifetime of the Python process.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ssl
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import certifi
|
||||
|
||||
|
||||
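# Per-host connection cap. aiohttp's defaults are limit=100 total and
# limit_per_host=0 (meaning *no* per-host cap), so this value restricts rather
# than enlarges the pool: at most 8 concurrent connections (gated probes plus
# downloads) go to one host, and further requests wait for a free slot.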
|
||||
_CONNECTOR_LIMIT_PER_HOST = 8
|
||||
|
||||
_session: Optional[aiohttp.ClientSession] = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def ssl_context() -> ssl.SSLContext:
    """Build a default TLS context that trusts certifi's CA bundle.

    The bundle path comes from ``certifi.where()`` and is passed as
    ``cafile``. Per the original author's note, this avoids
    CERTIFICATE_VERIFY_FAILED on Python installs whose OS trust store is
    not wired up (python.org macOS builds, slim containers).
    """
    ca_bundle = certifi.where()
    return ssl.create_default_context(cafile=ca_bundle)
|
||||
|
||||
|
||||
async def get_session() -> aiohttp.ClientSession:
    """Hand out the module-wide ``ClientSession``, building it when needed.

    A new session is built when none exists yet or the existing one has
    been closed. Construction happens under ``_lock``, and the condition is
    checked again once the lock is held so that callers racing on the first
    use end up sharing a single session.
    """
    global _session
    if _session is None or _session.closed:
        async with _lock:
            # Another coroutine may have created the session while this one
            # was waiting for the lock.
            if _session is None or _session.closed:
                pool = aiohttp.TCPConnector(
                    limit_per_host=_CONNECTOR_LIMIT_PER_HOST,
                    ssl=ssl_context(),
                )
                _session = aiohttp.ClientSession(connector=pool)
    return _session
|
||||
|
||||
|
||||
def parse_content_length(value: Optional[str]) -> Optional[int]:
    """Parse a byte-count header value, or None if absent/malformed/negative.

    Only a run of ASCII decimal digits is accepted, with surrounding
    whitespace tolerated. Forms that plain ``int()`` would accept, such as
    ``"1_000"``, ``"+5"`` or non-ASCII digits, are treated as malformed.

    Args:
        value: Raw header value, or ``None`` / empty when the header is absent.

    Returns:
        The non-negative byte count, or ``None`` if it cannot be trusted.
    """
    if not value:
        return None
    text = str(value).strip()
    # isdigit() alone is true for characters such as superscripts, and int()
    # alone accepts underscores and signs, so require ASCII digits only.
    if not (text.isascii() and text.isdigit()):
        return None
    try:
        return int(text)
    except ValueError:
        # e.g. the interpreter's limit on int/str conversion length.
        return None
|
||||
93
app/model_downloader/paths.py
Normal file
93
app/model_downloader/paths.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Path resolution for model downloads.
|
||||
|
||||
Model identifiers used across the download API are *relative destination
|
||||
paths* of the form ``<directory>/<filename>`` (e.g. ``loras/my_lora.safetensors``).
|
||||
This module turns one of those identifiers into an absolute on-disk path
|
||||
under one of ComfyUI's registered model folders, while rejecting unknown
|
||||
folders, path traversal, and other ill-formed inputs.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import folder_paths
|
||||
|
||||
|
||||
# Constrain components so a model_id can never escape its target directory.
|
||||
# - directory: a single path segment of safe chars
|
||||
# - filename: a single path segment of safe chars, must end with a model extension
|
||||
_SEGMENT_RE = re.compile(r"^[A-Za-z0-9._-]+$")
|
||||
|
||||
|
||||
class InvalidModelId(ValueError):
    """Signals a rejected ``model_id``.

    Raised when the identifier is malformed or names a model folder that
    is not registered. Being a ``ValueError`` subclass, it is also caught
    by handlers written for ``ValueError``.
    """
|
||||
|
||||
|
||||
def parse_model_id(model_id: str) -> Tuple[str, str]:
    """Split ``<directory>/<filename>`` and validate both components.

    Does NOT touch the filesystem.

    Args:
        model_id: Relative destination of the form ``<directory>/<filename>``.

    Returns:
        ``(directory, filename)``.

    Raises:
        InvalidModelId: If ``model_id`` is not a string, does not contain
            exactly one ``/``, has a segment with characters outside
            ``_SEGMENT_RE``, has a ``.`` or ``..`` segment, or names a
            directory missing from ``folder_paths.folder_names_and_paths``.
    """
    if not isinstance(model_id, str) or "/" not in model_id:
        raise InvalidModelId(f"model_id must be '<directory>/<filename>', got {model_id!r}")
    directory, _, filename = model_id.partition("/")
    if "/" in filename or not directory or not filename:
        raise InvalidModelId(f"model_id must be exactly one '/' separator, got {model_id!r}")
    # fullmatch, not match: with match() the pattern's ``$`` also matches just
    # before a trailing newline, which would let "name\n" through.
    # "." and ".." are made only of allowed characters but are directory
    # references, not names, so they are rejected explicitly.
    if not _SEGMENT_RE.fullmatch(directory) or directory in (".", ".."):
        raise InvalidModelId(f"invalid directory segment {directory!r}")
    if not _SEGMENT_RE.fullmatch(filename) or filename in (".", ".."):
        raise InvalidModelId(f"invalid filename segment {filename!r}")
    if directory not in folder_paths.folder_names_and_paths:
        raise InvalidModelId(f"unknown model folder {directory!r}")
    return directory, filename
|
||||
|
||||
|
||||
def resolve_existing(model_id: str) -> Optional[str]:
    """Look up an installed model by its ``model_id``.

    The id is validated with ``parse_model_id`` (which may raise
    ``InvalidModelId``) and the lookup is delegated to
    ``folder_paths.get_full_path``, whose result is returned unchanged.
    """
    folder, name = parse_model_id(model_id)
    located = folder_paths.get_full_path(folder, name)
    return located
|
||||
|
||||
|
||||
def resolve_destination(model_id: str) -> Tuple[str, str]:
    """Compute where a download for ``model_id`` should be written.

    Returns ``(final_path, tmp_path)``: the file name joined onto the first
    path that ``folder_paths.get_folder_paths`` reports for the directory,
    and that same path with ``.tmp`` appended.

    Raises ``InvalidModelId`` when the id fails ``parse_model_id`` or the
    directory has no registered path.
    """
    directory, filename = parse_model_id(model_id)
    candidates = folder_paths.get_folder_paths(directory)
    if candidates:
        final_path = os.path.join(candidates[0], filename)
        return final_path, final_path + ".tmp"
    raise InvalidModelId(f"no on-disk path registered for folder {directory!r}")
|
||||
|
||||
|
||||
def iter_all_tmp_paths():
    """Yield every ``*.tmp`` file under every registered model folder.

    Used at startup to sweep orphans left by crashed/restarted downloads.

    Only the top level of each root is scanned (no recursion), and each
    root is visited once even if registered under several folder names.

    Yields:
        The path of each regular file whose name ends with ``.tmp``.
    """
    seen_roots: set[str] = set()
    for directory in folder_paths.folder_names_and_paths:
        for root in folder_paths.get_folder_paths(directory):
            if root in seen_roots or not os.path.isdir(root):
                continue
            seen_roots.add(root)
            try:
                # The ``with`` closes the directory handle even if the
                # consumer stops iterating part-way through a folder.
                with os.scandir(root) as entries:
                    for entry in entries:
                        if entry.is_file() and entry.name.endswith(".tmp"):
                            yield entry.path
            except OSError:
                # Folder might be unreadable / missing on certain mounts —
                # not fatal, just skip it.
                continue
|
||||
Loading…
Reference in New Issue
Block a user