mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Add support for ENV based HF_TOKEN.
This commit is contained in:
parent
8a6e7906f7
commit
64c5853631
@ -17,6 +17,15 @@ AUTH_SCHEME_HEADER = "header"
|
||||
AUTH_SCHEME_QUERY = "query"
|
||||
AUTH_SCHEMES = (AUTH_SCHEME_BEARER, AUTH_SCHEME_HEADER, AUTH_SCHEME_QUERY)
|
||||
|
||||
# Hosts for which a bearer token can be sourced from the environment when no
|
||||
# stored credential matches. Values are the env var names to try, in order.
|
||||
# Only consulted during auto-resolve for an exact host match over https, so the
|
||||
# same per-hop boundary rules apply (e.g. the token is dropped on a redirect to
|
||||
# a CDN host). Kept here so the host->env-var mapping lives in one place.
|
||||
ENV_TOKEN_HOSTS = {
|
||||
"huggingface.co": ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"),
|
||||
}
|
||||
|
||||
|
||||
class DownloadStatus:
|
||||
QUEUED = "queued"
|
||||
|
||||
@ -10,6 +10,7 @@ which is exactly what these hubs expect.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode, urlsplit, urlunsplit
|
||||
@ -18,6 +19,7 @@ from app.model_downloader.constants import (
|
||||
AUTH_SCHEME_BEARER,
|
||||
AUTH_SCHEME_HEADER,
|
||||
AUTH_SCHEME_QUERY,
|
||||
ENV_TOKEN_HOSTS,
|
||||
)
|
||||
from app.model_downloader.credentials.store import normalize_host
|
||||
from app.model_downloader.database import queries
|
||||
@ -89,6 +91,14 @@ def _resolve_sync(
|
||||
for sub in queries.list_subdomain_credentials():
|
||||
if sub.enabled and _matches(sub, hop_host):
|
||||
return _build_auth(sub)
|
||||
|
||||
# Env fallback: only for an exact host match, and only after the DB lookups
|
||||
# miss, so a user-set credential always takes precedence. The token is never
|
||||
# persisted; it is read fresh from the environment on each hop.
|
||||
for var in ENV_TOKEN_HOSTS.get(hop_host, ()):
|
||||
token = os.environ.get(var)
|
||||
if token:
|
||||
return RequestAuth(headers={"Authorization": f"Bearer {token}"})
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@ -108,3 +108,59 @@ def test_resolver_never_crosses_host_boundary():
|
||||
finally:
|
||||
await CREDENTIAL_STORE.delete(view.id)
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- env-based HF token fallback -----
|
||||
|
||||
|
||||
def test_env_token_fallback_attaches_when_no_db_credential(monkeypatch):
|
||||
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
|
||||
|
||||
async def _run():
|
||||
# exact host over https -> env token attached
|
||||
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
|
||||
assert auth is not None
|
||||
assert auth.headers["Authorization"] == "Bearer env_hf_token"
|
||||
# non-https hop -> never attached
|
||||
assert await resolver.resolve_auth_for_hop("huggingface.co", "http") is None
|
||||
# CDN redirect host -> dropped (exact-host only)
|
||||
assert await resolver.resolve_auth_for_hop("cdn-lfs.huggingface.co", "https") is None
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_env_token_secondary_var_is_honored(monkeypatch):
|
||||
monkeypatch.delenv("HF_TOKEN", raising=False)
|
||||
monkeypatch.setenv("HUGGING_FACE_HUB_TOKEN", "env_hub_token")
|
||||
|
||||
async def _run():
|
||||
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
|
||||
assert auth is not None
|
||||
assert auth.headers["Authorization"] == "Bearer env_hub_token"
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_db_credential_takes_precedence_over_env(monkeypatch):
|
||||
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
|
||||
|
||||
async def _run():
|
||||
view = await CREDENTIAL_STORE.upsert("huggingface.co", "db_secret_key")
|
||||
try:
|
||||
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
|
||||
assert auth is not None
|
||||
assert auth.headers["Authorization"] == "Bearer db_secret_key"
|
||||
finally:
|
||||
await CREDENTIAL_STORE.delete(view.id)
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_env_token_does_not_leak_into_explicit_path(monkeypatch):
|
||||
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
|
||||
|
||||
async def _run():
|
||||
# An explicit credential id that doesn't resolve must stay None; the env
|
||||
# fallback only applies to the auto-resolve branch.
|
||||
auth = await resolver.resolve_auth_for_hop(
|
||||
"huggingface.co", "https", explicit_credential_id="does-not-exist"
|
||||
)
|
||||
assert auth is None
|
||||
asyncio.run(_run())
|
||||
|
||||
Loading…
Reference in New Issue
Block a user