Add support for ENV based HF_TOKEN.

This commit is contained in:
Talmaj Marinc 2026-07-01 12:02:19 +02:00
parent 8a6e7906f7
commit 64c5853631
3 changed files with 75 additions and 0 deletions

View File

@ -17,6 +17,15 @@ AUTH_SCHEME_HEADER = "header"
AUTH_SCHEME_QUERY = "query" AUTH_SCHEME_QUERY = "query"
AUTH_SCHEMES = (AUTH_SCHEME_BEARER, AUTH_SCHEME_HEADER, AUTH_SCHEME_QUERY) AUTH_SCHEMES = (AUTH_SCHEME_BEARER, AUTH_SCHEME_HEADER, AUTH_SCHEME_QUERY)
# Hosts for which a bearer token can be sourced from the environment when no
# stored credential matches. Values are the env var names to try, in order.
# Only consulted during auto-resolve for an exact host match over https, so the
# same per-hop boundary rules apply (e.g. the token is dropped on a redirect to
# a CDN host). Kept here so the host->env-var mapping lives in one place.
ENV_TOKEN_HOSTS = {
"huggingface.co": ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"),
}
class DownloadStatus: class DownloadStatus:
QUEUED = "queued" QUEUED = "queued"

View File

@ -10,6 +10,7 @@ which is exactly what these hubs expect.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional
from urllib.parse import urlencode, urlsplit, urlunsplit from urllib.parse import urlencode, urlsplit, urlunsplit
@ -18,6 +19,7 @@ from app.model_downloader.constants import (
AUTH_SCHEME_BEARER, AUTH_SCHEME_BEARER,
AUTH_SCHEME_HEADER, AUTH_SCHEME_HEADER,
AUTH_SCHEME_QUERY, AUTH_SCHEME_QUERY,
ENV_TOKEN_HOSTS,
) )
from app.model_downloader.credentials.store import normalize_host from app.model_downloader.credentials.store import normalize_host
from app.model_downloader.database import queries from app.model_downloader.database import queries
@ -89,6 +91,14 @@ def _resolve_sync(
for sub in queries.list_subdomain_credentials(): for sub in queries.list_subdomain_credentials():
if sub.enabled and _matches(sub, hop_host): if sub.enabled and _matches(sub, hop_host):
return _build_auth(sub) return _build_auth(sub)
# Env fallback: only for an exact host match, and only after the DB lookups
# miss, so a user-set credential always takes precedence. The token is never
# persisted; it is read fresh from the environment on each hop.
for var in ENV_TOKEN_HOSTS.get(hop_host, ()):
token = os.environ.get(var)
if token:
return RequestAuth(headers={"Authorization": f"Bearer {token}"})
return None return None

View File

@ -108,3 +108,59 @@ def test_resolver_never_crosses_host_boundary():
finally: finally:
await CREDENTIAL_STORE.delete(view.id) await CREDENTIAL_STORE.delete(view.id)
asyncio.run(_run()) asyncio.run(_run())
# ----- env-based HF token fallback -----
def test_env_token_fallback_attaches_when_no_db_credential(monkeypatch):
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
async def _run():
# exact host over https -> env token attached
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
assert auth is not None
assert auth.headers["Authorization"] == "Bearer env_hf_token"
# non-https hop -> never attached
assert await resolver.resolve_auth_for_hop("huggingface.co", "http") is None
# CDN redirect host -> dropped (exact-host only)
assert await resolver.resolve_auth_for_hop("cdn-lfs.huggingface.co", "https") is None
asyncio.run(_run())
def test_env_token_secondary_var_is_honored(monkeypatch):
monkeypatch.delenv("HF_TOKEN", raising=False)
monkeypatch.setenv("HUGGING_FACE_HUB_TOKEN", "env_hub_token")
async def _run():
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
assert auth is not None
assert auth.headers["Authorization"] == "Bearer env_hub_token"
asyncio.run(_run())
def test_db_credential_takes_precedence_over_env(monkeypatch):
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
async def _run():
view = await CREDENTIAL_STORE.upsert("huggingface.co", "db_secret_key")
try:
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
assert auth is not None
assert auth.headers["Authorization"] == "Bearer db_secret_key"
finally:
await CREDENTIAL_STORE.delete(view.id)
asyncio.run(_run())
def test_env_token_does_not_leak_into_explicit_path(monkeypatch):
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
async def _run():
# An explicit credential id that doesn't resolve must stay None; the env
# fallback only applies to the auto-resolve branch.
auth = await resolver.resolve_auth_for_hop(
"huggingface.co", "https", explicit_credential_id="does-not-exist"
)
assert auth is None
asyncio.run(_run())