mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 08:19:32 +08:00
Self-contained package under app/model_downloader/: - Allowlist + path-validated downloads (SSRF guard: HF/Civitai/localhost + model extension). - Streaming worker: writes to <final>.tmp, atomic rename on success, cooperative cancellation with epoch-based session identity, orphan .tmp sweep. - Unified availability probe with per-URL gated/size caching; is_hf_downloadable recomputed per call so login/license changes surface within one poll. - HuggingFace OAuth 2.0 PKCE flow with loopback callback server and on-disk (0600) token storage + transparent refresh. - Pydantic request/response schemas and aiohttp routes under api/. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
278 lines
10 KiB
Python
278 lines
10 KiB
Python
"""OAuth 2.0 PKCE flow against HuggingFace's authorization server.

Wired so that ``POST /api/hf-auth-login-start`` can:

1. Generate state + PKCE verifier/challenge in this process.
2. Spin up a short-lived loopback HTTP server on port 41954 to
   receive the redirect callback from HF.
3. Return the ``authorize_url`` for the frontend to open in a new tab.

After the user grants consent on huggingface.co, HF redirects to the
local callback URL with ``code`` and ``state``. The callback server
validates ``state`` (CSRF), exchanges the code for tokens via PKCE,
hands the resulting Token to ``HF_AUTH_STORE.set_token``, and shuts
itself down. If no callback arrives within ``CALLBACK_TIMEOUT_SECS``,
the server shuts down anyway and the user has to start over.

Before this can be exercised end-to-end, a maintainer must register a
HuggingFace OAuth app and substitute the ``HF_CLIENT_ID`` placeholder
below. See the comment above the constant for the exact steps.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import hashlib
|
|
import logging
|
|
import secrets
|
|
import threading
|
|
import time
|
|
|
|
import aiohttp
|
|
from aiohttp import web
|
|
|
|
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE
|
|
from app.model_downloader.hf_auth.token_store import Token
|
|
from app.model_downloader.http_client import ssl_context
|
|
|
|
|
|
# --- HF OAuth app registration -------------------------------------------- #
# NOTE: The OAuth client_id below is a placeholder. Before this feature can be
# exercised end-to-end, a maintainer must register a HuggingFace OAuth app
# under a Comfy-Org-controlled HF account and substitute its client_id here.
# Detailed walkthrough is in docs/server-side-model-downloads-handover.html
# ("HuggingFace OAuth app setup" section). Short version:
#   1. huggingface.co → Settings → Connected Apps → "Create app"
#   2. Default Scopes: check ``openid`` + ``profile`` (User Info) and
#      ``gated-repos`` (Repository Access). Leave everything else off.
#   3. Redirect URLs: exactly ``http://127.0.0.1:41954/api/auth/huggingface/callback``
#      — must match ``REDIRECT_URI`` below; change both in lockstep if you
#      change ``CALLBACK_PORT``.
#   4. Save → copy the resulting Client ID into ``HF_CLIENT_ID`` below.
# The client_id is not a secret (it travels through the user's browser in
# plaintext); HF's "Public app" type means there's no client secret to
# manage — PKCE replaces it.
#
# This value is sent as ``client_id`` on the authorize URL and on both
# token-endpoint requests in this module (code exchange and refresh).
HF_CLIENT_ID = "REPLACE_ME_WITH_COMFY_ORG_HF_OAUTH_CLIENT_ID"
|
|
|
|
# Loopback address, fixed port and path of the temporary callback server.
# The port is fixed (not ephemeral) because REDIRECT_URI has to match the
# redirect URL registered with the HF OAuth app exactly.
CALLBACK_HOST = "127.0.0.1"
CALLBACK_PORT = 41954
CALLBACK_PATH = "/api/auth/huggingface/callback"
# Sent as ``redirect_uri`` both on the authorize URL and in the code exchange.
REDIRECT_URI = f"http://{CALLBACK_HOST}:{CALLBACK_PORT}{CALLBACK_PATH}"

# HuggingFace OAuth endpoints: the user's browser is sent to AUTHORIZE_URL;
# this process POSTs the authorization_code / refresh_token grants to
# TOKEN_URL.
AUTHORIZE_URL = "https://huggingface.co/oauth/authorize"
TOKEN_URL = "https://huggingface.co/oauth/token"
# Minimal scope set for the feature:
#   - openid      : required by HF when the app uses OIDC at all
#   - profile     : lets ``HfApi.whoami(token=...)`` return a username for the
#                   settings UI; cosmetic but expected
#   - gated-repos : grants the token enough to call ``auth_check`` and
#                   download files from public gated repos the user has
#                   accepted the license for. The wider ``read-repos`` scope
#                   would also work (it includes ``gated-repos``) but it
#                   additionally grants private-repo read access, which we
#                   don't need and which makes the consent screen scarier
#                   for the user.
# Also used as the recorded Token.scope when a token response omits "scope".
SCOPE = "openid profile gated-repos"

# Maximum time the callback server stays up waiting for the user to
# complete consent on huggingface.co. Past this, the port closes and
# the user has to click "Log in" again.
CALLBACK_TIMEOUT_SECS = 300
|
|
|
|
|
|
# Process-wide lock so two simultaneous /api/hf-auth-login-start
# requests don't fight over port CALLBACK_PORT.
# start_login_flow acquires it non-blockingly (a second attempt gets
# OAuthInProgressError instead of waiting). It is released once the callback
# server has shut down or failed to bind. is_login_in_progress reports
# whether it is currently held.
_OAUTH_LOCK = threading.Lock()
|
|
|
|
|
|
class OAuthInProgressError(Exception):
    """Another OAuth attempt is already running.

    Raised by ``start_login_flow`` when ``_OAUTH_LOCK`` is already held.
    """
|
|
|
|
|
|
class OAuthCallbackError(Exception):
    """The OAuth callback could not complete.

    For example, HF redirected back with an ``error`` query parameter
    (such as a denied consent).
    """
|
|
|
|
|
|
# --- PKCE primitives ------------------------------------------------------ #
|
|
|
|
|
|
def _make_pkce() -> tuple[str, str, str]:
|
|
"""Return ``(verifier, challenge, state)``.
|
|
|
|
Verifier never leaves this process. Challenge and state travel
|
|
through the user's browser. State is checked on the callback to
|
|
prevent a malicious cross-origin redirect from injecting a token.
|
|
"""
|
|
verifier = secrets.token_urlsafe(64)
|
|
challenge = (
|
|
base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("ascii")).digest())
|
|
.rstrip(b"=")
|
|
.decode("ascii")
|
|
)
|
|
state = secrets.token_urlsafe(32)
|
|
return verifier, challenge, state
|
|
|
|
|
|
def _build_authorize_url(challenge: str, state: str) -> str:
    """Return the HF consent-page URL for one login attempt.

    The query carries this app's ``client_id``, the loopback
    ``redirect_uri``, the requested ``scope``, the CSRF ``state``, and the
    PKCE ``code_challenge`` with method ``S256``.
    """
    from urllib.parse import urlencode

    query_pairs = [
        ("client_id", HF_CLIENT_ID),
        ("redirect_uri", REDIRECT_URI),
        ("response_type", "code"),
        ("scope", SCOPE),
        ("state", state),
        ("code_challenge", challenge),
        ("code_challenge_method", "S256"),
    ]
    return AUTHORIZE_URL + "?" + urlencode(query_pairs)
|
|
|
|
|
|
# --- Token exchange ------------------------------------------------------- #
|
|
|
|
|
|
async def _post_token_request(data: dict[str, str]) -> dict:
    """POST a form-encoded grant to ``TOKEN_URL``; return the decoded JSON body.

    Raises:
        aiohttp.ClientResponseError: on a non-2xx response. This is the same
            exception type ``raise_for_status()`` raises, but the message also
            carries (a prefix of) the response body. For OAuth errors that
            body holds ``error`` / ``error_description`` (e.g.
            ``invalid_grant``), which ``raise_for_status()`` would discard.
    """
    timeout = aiohttp.ClientTimeout(total=30)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(TOKEN_URL, data=data, ssl=ssl_context()) as resp:
            if resp.status >= 400:
                # Decode leniently so an odd body never masks the HTTP error.
                detail = (await resp.read()).decode("utf-8", "replace").strip()[:500]
                reason = resp.reason or ""
                raise aiohttp.ClientResponseError(
                    resp.request_info,
                    resp.history,
                    status=resp.status,
                    message=f"{reason}: {detail}" if detail else reason,
                    headers=resp.headers,
                )
            return await resp.json()


def _token_from_body(body: dict, refresh_fallback: str | None) -> Token:
    """Build a Token from a token-endpoint JSON response.

    ``refresh_fallback`` is used when the response has no ``refresh_token``
    key. A missing or null ``expires_in`` is treated as 3600 seconds, and a
    missing ``scope`` falls back to ``SCOPE``.
    """
    expires_in = body.get("expires_in")
    return Token(
        access_token=body["access_token"],
        refresh_token=body.get("refresh_token", refresh_fallback),
        expires_at=time.time() + float(expires_in if expires_in is not None else 3600),
        scope=body.get("scope", SCOPE),
    )


async def _exchange_code(code: str, verifier: str) -> Token:
    """Trade the authorization code for an access+refresh token pair.

    *verifier* is the PKCE code verifier whose challenge was sent on the
    authorize URL. Raises ``aiohttp.ClientResponseError`` if HF rejects the
    exchange.
    """
    body = await _post_token_request(
        {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": REDIRECT_URI,
            "client_id": HF_CLIENT_ID,
            "code_verifier": verifier,
        }
    )
    return _token_from_body(body, refresh_fallback=None)


async def refresh_access_token(refresh_token: str) -> Token:
    """Trade a refresh_token for a new access (+ possibly refresh) token.

    Raises ``aiohttp.ClientResponseError`` if HF rejects the refresh (for
    example, a revoked or expired refresh token).
    """
    body = await _post_token_request(
        {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": HF_CLIENT_ID,
        }
    )
    # If HF doesn't rotate refresh tokens, keep using the existing one.
    return _token_from_body(body, refresh_fallback=refresh_token)
|
|
|
|
|
|
# --- Callback server ------------------------------------------------------ #
|
|
|
|
|
|
# Strong references to in-flight login tasks. The event loop keeps only weak
# references to tasks, so an unreferenced ``asyncio.create_task`` result can be
# garbage-collected before it finishes. Each task removes itself when done.
_BACKGROUND_TASKS: set[asyncio.Task[None]] = set()


async def start_login_flow() -> str:
    """Begin one OAuth attempt: bind the callback server, return the URL.

    Returns the URL the frontend should open in a new tab. The callback
    server is already listening when this returns, so HF's redirect always
    has somewhere to land.

    The callback server keeps running in the background until the user
    completes consent or ``CALLBACK_TIMEOUT_SECS`` elapses. Either way the
    lock and the port are released afterward.

    Raises:
        OAuthInProgressError: another attempt is already running.
        OAuthCallbackError: the callback port could not be bound (e.g. it
            is already in use). The lock is released before raising.
    """
    if not _OAUTH_LOCK.acquire(blocking=False):
        raise OAuthInProgressError()

    try:
        verifier, challenge, state = _make_pkce()
        authorize_url = _build_authorize_url(challenge, state)
        received: asyncio.Future[Token] = asyncio.get_running_loop().create_future()
        # Bind *before* handing out the URL. If the port is taken, the caller
        # gets an error instead of a URL whose redirect can never complete.
        runner = await _bind_callback_server(
            _make_callback_app(verifier, state, received)
        )
    except BaseException:
        # No background task exists yet, so nothing else would release it.
        _OAUTH_LOCK.release()
        raise

    task = asyncio.create_task(_finish_login(runner, received))
    _BACKGROUND_TASKS.add(task)
    task.add_done_callback(_BACKGROUND_TASKS.discard)
    return authorize_url


async def _run_callback_server(verifier: str, expected_state: str) -> None:
    """Listen for HF's redirect once, capture the token, then shut down.

    Expects the caller to already hold ``_OAUTH_LOCK``. The lock is released
    when this returns. ``start_login_flow`` does not use this helper (it binds
    before returning the URL); it is kept with its original signature for
    other callers.
    """
    received: asyncio.Future[Token] = asyncio.get_running_loop().create_future()
    try:
        runner = await _bind_callback_server(
            _make_callback_app(verifier, expected_state, received)
        )
    except OAuthCallbackError:
        # Bind failure was already logged. Release the lock so a future
        # attempt has a chance to succeed.
        if _OAUTH_LOCK.locked():
            _OAUTH_LOCK.release()
        return
    await _finish_login(runner, received)


def _make_callback_app(
    verifier: str, expected_state: str, received: "asyncio.Future[Token]"
) -> "web.Application":
    """Build the single-route aiohttp app that handles HF's redirect.

    The handler resolves *received* at most once:
    - with the exchanged Token on success;
    - with an exception when HF reports an ``error`` or the exchange fails.

    Requests with a wrong or missing ``state`` get a 400 and leave *received*
    untouched, so a stray or forged request cannot abort a pending login.
    """

    async def handler(request: web.Request) -> web.Response:
        try:
            got_state = request.query.get("state", "")
            # Constant-time CSRF check. expected_state is ASCII (token_urlsafe),
            # and compare_digest raises TypeError for non-ASCII str input, so
            # treat non-ASCII values as a plain mismatch up front.
            if not (got_state.isascii() and secrets.compare_digest(got_state, expected_state)):
                return web.Response(status=400, text="state mismatch")
            err = request.query.get("error")
            if err:
                if not received.done():
                    received.set_exception(OAuthCallbackError(f"HF returned: {err}"))
                return web.Response(status=400, text=f"OAuth error: {err}")
            code = request.query.get("code")
            if not code:
                return web.Response(status=400, text="missing code")
            tok = await _exchange_code(code, verifier)
            if not received.done():
                received.set_result(tok)
            return web.Response(
                content_type="text/html",
                text=(
                    "<html><body style='font-family:sans-serif;padding:40px'>"
                    "<h2>HuggingFace login successful</h2>"
                    "<p>You can close this tab and return to ComfyUI.</p>"
                    "</body></html>"
                ),
            )
        except Exception as exc:
            if not received.done():
                received.set_exception(exc)
            return web.Response(status=500, text=str(exc))

    app = web.Application()
    app.router.add_get(CALLBACK_PATH, handler)
    return app


async def _bind_callback_server(app: "web.Application") -> "web.AppRunner":
    """Start serving *app* on ``CALLBACK_HOST:CALLBACK_PORT``; return its runner.

    Raises:
        OAuthCallbackError: the port could not be bound (already in use, or
            another socket-bind failure). The runner is cleaned up first.
    """
    runner = web.AppRunner(app)
    await runner.setup()
    # No explicit reuse_address. asyncio's default already enables it on
    # POSIX. Forcing SO_REUSEADDR on Windows would let this listener co-bind
    # a port another process is using instead of failing, which is exactly
    # the case we need to detect.
    site = web.TCPSite(runner, CALLBACK_HOST, CALLBACK_PORT)
    try:
        await site.start()
    except BaseException as exc:
        # Don't leak the runner set up above when the bind fails.
        await runner.cleanup()
        if isinstance(exc, OSError):
            logging.warning("[hf_auth] could not bind callback port: %s", exc)
            raise OAuthCallbackError(
                f"could not bind callback port {CALLBACK_HOST}:{CALLBACK_PORT}: {exc}"
            ) from exc
        raise
    return runner


async def _finish_login(runner: "web.AppRunner", received: "asyncio.Future[Token]") -> None:
    """Wait for the callback, store the token, then tear everything down.

    Outcomes are logged rather than raised. Whatever happens (success, HF
    error, timeout, cancellation), the callback server is shut down and
    ``_OAUTH_LOCK`` is released.
    """
    try:
        token = await asyncio.wait_for(received, timeout=CALLBACK_TIMEOUT_SECS)
    except asyncio.TimeoutError:
        logging.info("[hf_auth] OAuth login timed out after %ds", CALLBACK_TIMEOUT_SECS)
    except OAuthCallbackError as e:
        logging.warning("[hf_auth] OAuth callback error: %s", e)
    except Exception as e:
        logging.warning("[hf_auth] unexpected OAuth failure: %s", e)
    else:
        HF_AUTH_STORE.set_token(token)
        logging.info("[hf_auth] OAuth login complete")
    finally:
        try:
            await runner.cleanup()
        finally:
            # Release even if cleanup raised, or every later login would be
            # rejected with OAuthInProgressError.
            if _OAUTH_LOCK.locked():
                _OAUTH_LOCK.release()
|
|
|
|
|
|
def is_login_in_progress() -> bool:
    """Report whether an OAuth attempt currently holds the process-wide lock.

    ``start_login_flow`` takes the lock. It is given back once the callback
    server has shut down or could not be bound, so this is ``True`` while a
    login is pending.
    """
    lock_held = _OAUTH_LOCK.locked()
    return lock_held
|
|
|
|
|
|
# Public API of this module. Private helpers (``_make_pkce``,
# ``_build_authorize_url``, ...) are intentionally not listed, so
# ``from ... import *`` exports only these names. OAuthCallbackError is
# included because it is the module's error type for a failed callback.
__all__ = [
    "start_login_flow",
    "refresh_access_token",
    "is_login_in_progress",
    "OAuthInProgressError",
    "OAuthCallbackError",
    "CALLBACK_TIMEOUT_SECS",
]
|