mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Add support for ENV based HF_TOKEN.
This commit is contained in:
parent
8a6e7906f7
commit
64c5853631
@ -17,6 +17,15 @@ AUTH_SCHEME_HEADER = "header"
|
|||||||
AUTH_SCHEME_QUERY = "query"
|
AUTH_SCHEME_QUERY = "query"
|
||||||
AUTH_SCHEMES = (AUTH_SCHEME_BEARER, AUTH_SCHEME_HEADER, AUTH_SCHEME_QUERY)
|
AUTH_SCHEMES = (AUTH_SCHEME_BEARER, AUTH_SCHEME_HEADER, AUTH_SCHEME_QUERY)
|
||||||
|
|
||||||
|
# Hosts for which a bearer token can be sourced from the environment when no
|
||||||
|
# stored credential matches. Values are the env var names to try, in order.
|
||||||
|
# Only consulted during auto-resolve for an exact host match over https, so the
|
||||||
|
# same per-hop boundary rules apply (e.g. the token is dropped on a redirect to
|
||||||
|
# a CDN host). Kept here so the host->env-var mapping lives in one place.
|
||||||
|
ENV_TOKEN_HOSTS = {
|
||||||
|
"huggingface.co": ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class DownloadStatus:
|
class DownloadStatus:
|
||||||
QUEUED = "queued"
|
QUEUED = "queued"
|
||||||
|
|||||||
@ -10,6 +10,7 @@ which is exactly what these hubs expect.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from urllib.parse import urlencode, urlsplit, urlunsplit
|
from urllib.parse import urlencode, urlsplit, urlunsplit
|
||||||
@ -18,6 +19,7 @@ from app.model_downloader.constants import (
|
|||||||
AUTH_SCHEME_BEARER,
|
AUTH_SCHEME_BEARER,
|
||||||
AUTH_SCHEME_HEADER,
|
AUTH_SCHEME_HEADER,
|
||||||
AUTH_SCHEME_QUERY,
|
AUTH_SCHEME_QUERY,
|
||||||
|
ENV_TOKEN_HOSTS,
|
||||||
)
|
)
|
||||||
from app.model_downloader.credentials.store import normalize_host
|
from app.model_downloader.credentials.store import normalize_host
|
||||||
from app.model_downloader.database import queries
|
from app.model_downloader.database import queries
|
||||||
@ -89,6 +91,14 @@ def _resolve_sync(
|
|||||||
for sub in queries.list_subdomain_credentials():
|
for sub in queries.list_subdomain_credentials():
|
||||||
if sub.enabled and _matches(sub, hop_host):
|
if sub.enabled and _matches(sub, hop_host):
|
||||||
return _build_auth(sub)
|
return _build_auth(sub)
|
||||||
|
|
||||||
|
# Env fallback: only for an exact host match, and only after the DB lookups
|
||||||
|
# miss, so a user-set credential always takes precedence. The token is never
|
||||||
|
# persisted; it is read fresh from the environment on each hop.
|
||||||
|
for var in ENV_TOKEN_HOSTS.get(hop_host, ()):
|
||||||
|
token = os.environ.get(var)
|
||||||
|
if token:
|
||||||
|
return RequestAuth(headers={"Authorization": f"Bearer {token}"})
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -108,3 +108,59 @@ def test_resolver_never_crosses_host_boundary():
|
|||||||
finally:
|
finally:
|
||||||
await CREDENTIAL_STORE.delete(view.id)
|
await CREDENTIAL_STORE.delete(view.id)
|
||||||
asyncio.run(_run())
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
|
||||||
|
# ----- env-based HF token fallback -----
|
||||||
|
|
||||||
|
|
||||||
|
def test_env_token_fallback_attaches_when_no_db_credential(monkeypatch):
|
||||||
|
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
# exact host over https -> env token attached
|
||||||
|
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
|
||||||
|
assert auth is not None
|
||||||
|
assert auth.headers["Authorization"] == "Bearer env_hf_token"
|
||||||
|
# non-https hop -> never attached
|
||||||
|
assert await resolver.resolve_auth_for_hop("huggingface.co", "http") is None
|
||||||
|
# CDN redirect host -> dropped (exact-host only)
|
||||||
|
assert await resolver.resolve_auth_for_hop("cdn-lfs.huggingface.co", "https") is None
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_env_token_secondary_var_is_honored(monkeypatch):
|
||||||
|
monkeypatch.delenv("HF_TOKEN", raising=False)
|
||||||
|
monkeypatch.setenv("HUGGING_FACE_HUB_TOKEN", "env_hub_token")
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
|
||||||
|
assert auth is not None
|
||||||
|
assert auth.headers["Authorization"] == "Bearer env_hub_token"
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_db_credential_takes_precedence_over_env(monkeypatch):
|
||||||
|
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
view = await CREDENTIAL_STORE.upsert("huggingface.co", "db_secret_key")
|
||||||
|
try:
|
||||||
|
auth = await resolver.resolve_auth_for_hop("huggingface.co", "https")
|
||||||
|
assert auth is not None
|
||||||
|
assert auth.headers["Authorization"] == "Bearer db_secret_key"
|
||||||
|
finally:
|
||||||
|
await CREDENTIAL_STORE.delete(view.id)
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_env_token_does_not_leak_into_explicit_path(monkeypatch):
|
||||||
|
monkeypatch.setenv("HF_TOKEN", "env_hf_token")
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
# An explicit credential id that doesn't resolve must stay None; the env
|
||||||
|
# fallback only applies to the auto-resolve branch.
|
||||||
|
auth = await resolver.resolve_auth_for_hop(
|
||||||
|
"huggingface.co", "https", explicit_credential_id="does-not-exist"
|
||||||
|
)
|
||||||
|
assert auth is None
|
||||||
|
asyncio.run(_run())
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user